diff --git a/train/.gitkeep b/train/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/train/Dockerfile b/train/Dockerfile new file mode 100644 index 0000000..2456ea4 --- /dev/null +++ b/train/Dockerfile @@ -0,0 +1,29 @@ +# Use NVIDIA base image +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 + +# Set the working directory +WORKDIR /app + + +# Set environment variable for GitPython +ENV GIT_PYTHON_REFRESH=quiet + +# Install Python and pip +RUN apt-get update && apt-get install -y --fix-broken && \ + apt-get install -y \ + python3 \ + python3-pip \ + python3-distutils \ + python3-setuptools \ + python3-wheel && \ + rm -rf /var/lib/apt/lists/* + +# Copy and install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the application +COPY . . + +# Specify the entry point +CMD ["python3", "train.py", "with", "config_params_docker.json"] diff --git a/train/LICENSE b/train/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/train/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/train/README.md b/train/README.md new file mode 100644 index 0000000..7c69a10 --- /dev/null +++ b/train/README.md @@ -0,0 +1,90 @@ +# Pixelwise Segmentation +> Pixelwise segmentation for document images + +## Introduction +This repository contains the source code for training an encoder model for document image segmentation. + +## Installation +Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pixelwise_segmentation.git` or download and unpack the [ZIP](https://github.com/qurator-spk/sbb_pixelwise_segmentation/archive/master.zip). + +### Pretrained encoder +Download our pretrained weights and add them to a ``pretrained_model`` folder: +https://qurator-data.de/sbb_pixelwise_segmentation/pretrained_encoder/ + +### Helpful tools +* [`pagexml2img`](https://github.com/qurator-spk/page2img) +> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. 
In the case of a 3-D RGB image,
+each class is assigned an RGB value and, besides the images, a text file listing the classes is also produced.
+* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130)
+> Convert COCO GT or results for a single image to a segmentation map and write it to disk.
+* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py)
+> Extract region classes and their colours in mask (pseg) images. Allows the color map as a free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions overlap.
+
+## Usage
+
+### Train
+To train a model, run: ``python train.py with config_params.json``
+
+### Train using Docker
+
+#### Build the Docker image
+
+   ```bash
+   docker build -t model-training .
+   ```
+#### Run the Docker image
+   ```bash
+   docker run --gpus all -v /host/path/to/entry_point_dir:/entry_point_dir model-training
+   ```
+
+### Ground truth format
+Labels for each pixel are identified by a number. So in the binary case,
+``n_classes`` should be set to ``2`` and the labels should be ``0`` and ``1``
+for each class and pixel.
+
+In the multiclass case, set ``n_classes`` to the number of classes you have
+and produce labels whose pixel values run from ``0, 1, 2, ..., n_classes-1``.
+The labels should be in PNG format.
+Our labels are 3-channel PNG images, but only the information of the first channel is used.
+If you have an image label with height and width of 10, for a binary case the first channel should look like this:
+
+    Label: [ [1, 0, 0, 1, 1, 0, 0, 1, 0, 0],
+             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+             ...,
+             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ]
+
+    This means that you have a `10*10*3` image in which `pixel[0,0]` belongs
+    to class `1` and `pixel[0,1]` belongs to class `0` (a sketch for generating such labels is given at the end of this README).
+
+    A small sample of training data for a binarization experiment, containing images and labels folders, can be found here: [Training data sample](https://qurator-data.de/~vahid.rezanezhad/binarization_training_data_sample/).
+
+### Training, evaluation and output
+The train and evaluation folders should contain subfolders of images and labels.
+The output folder should be an empty folder where the output model will be written to.
+
+### Parameter configuration
+* patches: If you want to break input images into smaller patches (the input size of the model), set this parameter to ``true``. If the model should see the whole image at once, as in page extraction, set patches to ``false``.
+* n_batch: Batch size used at each iteration.
+* n_classes: Number of classes. In the case of binary classification this should be 2.
+* n_epochs: Number of epochs.
+* input_height: Height of the model's input.
+* input_width: Width of the model's input.
+* weight_decay: Weight decay of the l2 regularization of the model layers.
+* augmentation: To apply any kind of augmentation, this parameter must first be set to ``true``.
+* flip_aug: If ``true``, different types of flip will be applied to the image. The types of flips are given with "flip_index" in the train.py file.
+* blur_aug: If ``true``, different types of blurring will be applied to the image. The types of blurring are given with "blur_k" in the train.py file.
+* scaling: If ``true``, scaling will be applied to the image. The scale factors are given with "scales" in the train.py file.
+* rotation_not_90: If ``true``, rotation (other than 90 degrees) will be applied to the image. The rotation angles are given with "thetha" in the train.py file.
+* rotation: If ``true``, 90-degree rotation will be applied to the image.
+* binarization: If ``true``, Otsu thresholding will be applied to augment the input data with binarized images.
+* scaling_bluring: If ``true``, a combination of scaling and blurring will be applied to the image.
+* scaling_binarization: If ``true``, a combination of scaling and binarization will be applied to the image.
+* scaling_flip: If ``true``, a combination of scaling and flipping will be applied to the image.
+* continue_training: If ``true``, you have already trained a model and want to continue the training. In that case, provide the directory of the trained model with "dir_of_start_model" and an index for naming the models. For example, if you have already trained for 3 epochs, your last index is 2; to continue from model_1.h5, set "index_start" to 3 so that model naming starts with index 3.
+* weighted_loss: If ``true``, weighted categorical_crossentropy is applied as the loss function. Be careful: if this is set to ``true``, the parameter "is_loss_soft_dice" should be ``false``.
+* data_is_provided: If you have already provided the input data, set this to ``true``. Make sure the train and eval data are in "dir_output": once training data has been provided, it is resized, augmented, and written to the train and eval sub-directories of "dir_output".
+* dir_train: Directory of raw images and labels; dir_train should include the two subdirectories "images" and "labels". These data are not yet prepared (resized, augmented) for training the model. When this tool runs, the raw data are transformed to the size needed by the model and written to the train and eval directories in "dir_output", each of which includes "images" and "labels" sub-directories.
+
+#### Additional documentation
+Please check the [wiki](https://github.com/qurator-spk/sbb_pixelwise_segmentation/wiki).
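+
+#### Example: creating a binary label image
+A minimal sketch of producing a binary label PNG in the format described above. This is not part of the training code; the file names are hypothetical, and OpenCV and NumPy are assumed to be installed:
+
+```python
+import numpy as np
+import cv2
+
+# Hypothetical paths; adjust them to your own data layout.
+img = cv2.imread("images/page_0001.png", cv2.IMREAD_GRAYSCALE)
+
+# Foreground (ink) pixels become class 1, background class 0,
+# here via Otsu thresholding of a grayscale scan.
+_, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+label = (binary > 0).astype(np.uint8)  # values in {0, 1}, i.e. n_classes = 2
+
+# Only the first channel is read by the trainer, but the labels are stored
+# as 3-channel PNGs, so the class map is replicated across all channels.
+label_3ch = np.stack([label, label, label], axis=-1)
+cv2.imwrite("labels/page_0001.png", label_3ch)
+```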
diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/build_model_load_pretrained_weights_and_save.py b/train/build_model_load_pretrained_weights_and_save.py new file mode 100644 index 0000000..125611e --- /dev/null +++ b/train/build_model_load_pretrained_weights_and_save.py @@ -0,0 +1,29 @@ +import os +import sys +import tensorflow as tf +import warnings +from tensorflow.keras.optimizers import * +from sacred import Experiment +from models import * +from utils import * +from metrics import * + + +def configuration(): + gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) + session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) + + +if __name__ == '__main__': + n_classes = 2 + input_height = 224 + input_width = 448 + weight_decay = 1e-6 + pretraining = False + dir_of_weights = 'model_bin_sbb_ens.h5' + + # configuration() + + model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) + model.load_weights(dir_of_weights) + model.save('./name_in_another_python_version.h5') diff --git a/train/config_params.json b/train/config_params.json new file mode 100644 index 0000000..1db8026 --- /dev/null +++ b/train/config_params.json @@ -0,0 +1,58 @@ +{ + "backbone_type" : "transformer", + "task": "segmentation", + "n_classes" : 2, + "n_epochs" : 0, + "input_height" : 448, + "input_width" : 448, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : false, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : false, + "adding_rgb_background": true, + "adding_rgb_foreground": true, + "add_red_textlines": false, + "channels_shuffling": false, + "degrading": false, + "brightening": false, + "binarization" : true, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "transformer_num_patches_xy": [56, 56], + "transformer_patchsize_x": 4, + "transformer_patchsize_y": 4, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [5, -5], + "number_of_backgrounds_per_image": 2, + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new", + "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", + "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new", + "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", + "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", + "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" + +} diff --git a/train/config_params_docker.json b/train/config_params_docker.json new file mode 100644 index 0000000..45f87d3 --- /dev/null +++ b/train/config_params_docker.json @@ -0,0 +1,54 @@ +{ + 
"backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 3, + "n_epochs" : 1, + "input_height" : 672, + "input_width" : 448, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "patches" : false, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "blur_aug" : true, + "scaling" : true, + "adding_rgb_background": false, + "adding_rgb_foreground": false, + "add_red_textlines": false, + "channels_shuffling": true, + "degrading": true, + "brightening": true, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": true, + "transformer_num_patches_xy": [14, 21], + "transformer_patchsize_x": 1, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": true, + "blur_k" : ["blur","gauss","median"], + "scales" : [0.6, 0.7, 0.8, 0.9], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [5, -5], + "number_of_backgrounds_per_image": 2, + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": true, + "data_is_provided": false, + "dir_train": "/entry_point_dir/train", + "dir_eval": "/entry_point_dir/eval", + "dir_output": "/entry_point_dir/output" +} diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json new file mode 100644 index 0000000..9116ce3 --- /dev/null +++ b/train/custom_config_page2label.json @@ -0,0 +1,8 @@ +{ +"use_case": "textline", +"textregions":{ "rest_as_paragraph": 1, "header":2 , "heading":2 , "marginalia":3 }, +"imageregion":4, +"separatorregion":5, +"graphicregions" :{"rest_as_decoration":6}, +"columns_width":{"1":1000, "2":1300, "3":1600, "4":2000, "5":2300, "6":2500} +} diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py new file mode 100644 index 0000000..388fced --- /dev/null +++ b/train/generate_gt_for_training.py @@ -0,0 +1,567 @@ +import click +import json +from gt_gen_utils import * +from tqdm import tqdm +from pathlib import Path +from PIL import Image, ImageDraw, ImageFont + +@click.group() +def main(): + pass + +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_images", + "-di", + help="directory of org images. If print space cropping or scaling is needed for labels it would be great to provide the original images to apply the same function on them. So if -ps is not set true or in config files no columns_width key is given this argumnet can be ignored. 
File stems in this directory should be the same as those in dir_xml.",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--dir_out_images",
+    "-doi",
+    help="directory where the original images, after processing (e.g. print space cropping or scaling), will be written.",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--dir_out",
+    "-do",
+    help="directory where ground truth label images will be written",
+    type=click.Path(exists=True, file_okay=False),
+)
+
+@click.option(
+    "--config",
+    "-cfg",
+    help="config file of the preferred layout or use case.",
+    type=click.Path(exists=True, dir_okay=False),
+)
+
+@click.option(
+    "--type_output",
+    "-to",
+    help="defines the output format; pass 2d or 3d. A 2d output is an image array in which each class is filled with a distinct pixel value (it is still stored with three channels, but all channels carry the same values, so one channel suffices). A 3d output is an image array encoded with an RGB color per class. The file will be saved one directory up.",
+)
+@click.option(
+    "--printspace",
+    "-ps",
+    is_flag=True,
+    help="if set, the generated labels (and, if original images are provided, the images as well) are cropped to the print space, and the cropped labels and images are written to the output directories.",
+)
+
+def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
+    if config:
+        with open(config) as f:
+            config_params = json.load(f)
+    else:
+        print("passed")
+        config_params = None
+    gt_list = get_content_of_dir(dir_xml)
+    get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images)
+
+@main.command()
+@click.option(
+    "--dir_imgs",
+    "-dis",
+    help="directory of high-resolution images.",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--dir_out_images",
+    "-dois",
+    help="directory where degraded images will be written.",
+    type=click.Path(exists=True, file_okay=False),
+)
+
+@click.option(
+    "--dir_out_labels",
+    "-dols",
+    help="directory where the original images will be written as labels.",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--scales",
+    "-scs",
+    help="JSON file containing a dictionary with the list of scales.",
+    type=click.Path(exists=True, dir_okay=False),
+)
+def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
+    ls_imgs = os.listdir(dir_imgs)
+    with open(scales) as f:
+        scale_dict = json.load(f)
+    ls_scales = scale_dict['scales']
+
+    for img in tqdm(ls_imgs):
+        # os.path.splitext is robust to file names that contain dots
+        img_name, img_ext = os.path.splitext(img)
+        image = cv2.imread(os.path.join(dir_imgs, img))
+        for i, scale in enumerate(ls_scales):
+            height_sc = int(image.shape[0]*scale)
+            width_sc = int(image.shape[1]*scale)
+
+            image_down_scaled = resize_image(image, height_sc, width_sc)
+            image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1])
+
+            cv2.imwrite(os.path.join(dir_out_images, img_name+'_'+str(i)+img_ext), image_back_to_org_scale)
+            cv2.imwrite(os.path.join(dir_out_labels, img_name+'_'+str(i)+img_ext), image)
+
+
+@main.command()
+@click.option(
+    "--dir_xml",
+    "-dx",
+    help="directory of GT page-xml files",
+    type=click.Path(exists=True, file_okay=False),
+)
+
+@click.option(
+    "--dir_out_modal_image",
+    "-domi",
+    help="directory where ground truth images will be written",
+    type=click.Path(exists=True, file_okay=False),
+)
+
+@click.option(
+    "--dir_out_classes",
+    "-docl",
+    help="directory where ground truth classes will be written",
+    type=click.Path(exists=True,
file_okay=False), +) + +@click.option( + "--input_height", + "-ih", + help="input height", +) +@click.option( + "--input_width", + "-iw", + help="input width", +) +@click.option( + "--min_area_size", + "-min", + help="min area size of regions considered for reading order training.", +) + +@click.option( + "--min_area_early", + "-min_early", + help="If you have already generated a training dataset using a specific minimum area value and now wish to create a dataset with a smaller minimum area value, you can avoid regenerating the previous dataset by providing the earlier minimum area value. This will ensure that only the missing data is generated.", +) + +def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early): + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + input_height = int(input_height) + input_width = int(input_width) + min_area = float(min_area_size) + if min_area_early: + min_area_early = float(min_area_early) + + + indexer_start= 0#55166 + max_area = 1 + #min_area = 0.0001 + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = ind_xml.split('.')[0] + _, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) + + id_all_text = id_paragraph + id_header + co_text_all = co_text_paragraph + co_text_header + + + _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) + + img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + + for j in range(len(cy_main)): + img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 + + + texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] + texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] + + + co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area, min_area_early) + + + arg_array = np.array(range(len(texts_corr_order_index_int))) + + labels_con = np.zeros((y_len,x_len,len(arg_array)),dtype='uint8') + for i in range(len(co_text_all)): + img_label = np.zeros((y_len,x_len,3),dtype='uint8') + img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) + + img_label[:,:,0][img_poly[:,:,0]==5] = 2 + img_label[:,:,0][img_header_and_sep[:,:]==1] = 3 + + labels_con[:,:,i] = img_label[:,:,0] + + labels_con = resize_image(labels_con, input_height, input_width) + img_poly = resize_image(img_poly, input_height, input_width) + + + for i in range(len(texts_corr_order_index_int)): + for j in range(len(texts_corr_order_index_int)): + if i!=j: + if regions_ar_less_than_early_min: + if regions_ar_less_than_early_min[i]==1: + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) + final_f_name = f_name+'_'+str(indexer+indexer_start) + order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] + if order_class_condition<0: + class_type = 1 + else: + class_type = 0 + + input_multi_visual_modal[:,:,0] = labels_con[:,:,i] + input_multi_visual_modal[:,:,1] = img_poly[:,:,0] + input_multi_visual_modal[:,:,2] = labels_con[:,:,j] + + 
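+                            # Each training sample is a 3-channel image: channel 0 holds the mask of
+                            # region i, channel 1 the full page layout, and channel 2 the mask of
+                            # region j; the accompanying .npy stores the pairwise reading-order class
+                            # (1 if region i precedes region j in the ground truth, else 0). The
+                            # '_missed' suffix marks pairs whose first region fell below the earlier
+                            # min-area threshold, so only data missing from an earlier run is generated.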
np.save(os.path.join(dir_out_classes,final_f_name+'_missed.npy' ), class_type) + + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'_missed.png' ), input_multi_visual_modal) + indexer = indexer+1 + + else: + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) + final_f_name = f_name+'_'+str(indexer+indexer_start) + order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] + if order_class_condition<0: + class_type = 1 + else: + class_type = 0 + + input_multi_visual_modal[:,:,0] = labels_con[:,:,i] + input_multi_visual_modal[:,:,1] = img_poly[:,:,0] + input_multi_visual_modal[:,:,2] = labels_con[:,:,j] + + np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) + + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) + indexer = indexer+1 + + +@main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-o", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_imgs", + "-di", + help="directory where the overlayed plots will be written", ) + +def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + else: + xml_files_ind = [xml_file] + + indexer_start= 0#55166 + #min_area = 0.0001 + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + #xml_file = os.path.join(dir_xml,ind_xml ) + + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') + + #f_name = ind_xml.split('.')[0] + _, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) + + id_all_text = id_paragraph + id_header + co_text_all = co_text_paragraph + co_text_header + + + cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_all) + + texts_corr_order_index = [int(index_tot_regions[tot_region_ref.index(i)]) for i in id_all_text ] + #texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] + + + #cx_ordered = np.array(cx_main)[np.array(texts_corr_order_index)] + #cx_ordered = cx_ordered.astype(np.int32) + + cx_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cx_main), key=lambda x: \ + x[0], reverse=False)] + #cx_ordered = cx_ordered.astype(np.int32) + + cy_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cy_main), key=lambda x: \ + x[0], reverse=False)] + #cy_ordered = cy_ordered.astype(np.int32) + + + color = (0, 0, 255) + thickness = 20 + if dir_imgs: + layout = np.zeros( (y_len,x_len,3) ) + layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1)) + + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + overlayed = overlay_layout_on_image(layout, img, 
cx_ordered, cy_ordered, color, thickness) + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed) + + else: + img = np.zeros( (y_len,x_len,3) ) + img = cv2.fillPoly(img, pts =co_text_all, color=(255,0,0)) + for i in range(len(cx_ordered)-1): + start_point = (int(cx_ordered[i]), int(cy_ordered[i])) + end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) + img = cv2.arrowedLine(img, start_point, end_point, + color, thickness, tipLength = 0.03) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img) + + +@main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-o", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_imgs", + "-di", + help="directory of images where textline segmentation will be overlayed", ) + +def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + else: + xml_files_ind = [xml_file] + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) + + added_image = visualize_image_from_contours(co_tetxlines, img) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + + + +@main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-o", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_imgs", + "-di", + help="directory of images where textline segmentation will be overlayed", ) + +def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + else: + xml_files_ind = [xml_file] + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') + + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) + + + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], 
co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + + + + +@main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-o", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + + +def visualize_ocr_text(xml_file, dir_xml, dir_out): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + else: + xml_files_ind = [xml_file] + + font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = ImageFont.truetype(font_path, 40) + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') + + co_tetxlines, y_len, x_len, ocr_texts = get_textline_contours_and_ocr_text(xml_file) + + total_bb_coordinates = [] + + image_text = Image.new("RGB", (x_len, y_len), "white") + draw = ImageDraw.Draw(image_text) + + + + for index, cnt in enumerate(co_tetxlines): + x,y,w,h = cv2.boundingRect(cnt) + #total_bb_coordinates.append([x,y,w,h]) + + #fit_text_single_line + + #x_bb = bb_ind[0] + #y_bb = bb_ind[1] + #w_bb = bb_ind[2] + #h_bb = bb_ind[3] + if ocr_texts[index]: + + + is_vertical = h > 2*w # Check orientation + font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) + + if is_vertical: + + vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) + + text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped + text_draw = ImageDraw.Draw(text_img) + text_draw.text((0, 0), ocr_texts[index], font=vertical_font, fill="black") + + # Rotate text image by 90 degrees + rotated_text = text_img.rotate(90, expand=1) + + # Calculate paste position (centered in bbox) + paste_x = x + (w - rotated_text.width) // 2 + paste_y = y + (h - rotated_text.height) // 2 + + image_text.paste(rotated_text, (paste_x, paste_y), rotated_text) # Use rotated image as mask + else: + text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x + (w - text_width) // 2 # Center horizontally + text_y = y + (h - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) + image_text.save(os.path.join(dir_out, f_name+'.png')) + +if __name__ == "__main__": + main() diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py new file mode 100644 index 0000000..38d48ca --- /dev/null +++ b/train/gt_gen_utils.py @@ -0,0 +1,1838 @@ +import click +import sys +import os +import numpy as np +import warnings +import xml.etree.ElementTree as ET +from tqdm import tqdm +import cv2 +from shapely import geometry +from pathlib import Path +import matplotlib.pyplot as plt +from PIL import Image, ImageDraw, ImageFont + + +KERNEL = np.ones((5, 5), np.uint8) + +with warnings.catch_warnings(): 
+ warnings.simplefilter("ignore") + + +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img): + alpha = 0.5 + + blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 + + col_header = (173, 216, 230) + col_drop = (0, 191, 255) + boundary_color = (143, 216, 200)#(0, 0, 255) # Dark gray for the boundary + col_par = (0, 0, 139) # Lighter gray for the filled area + col_image = (0, 100, 0) + col_sep = (255, 0, 0) + col_marginal = (106, 90, 205) + col_table = (0, 90, 205) + + if len(co_image)>0: + cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour + + if len(co_sep)>0: + cv2.drawContours(blank_image, co_sep, -1, col_sep, thickness=cv2.FILLED) # Fill the contour + + + if len(co_header)>0: + cv2.drawContours(blank_image, co_header, -1, col_header, thickness=cv2.FILLED) # Fill the contour + + if len(co_par)>0: + cv2.drawContours(blank_image, co_par, -1, col_par, thickness=cv2.FILLED) # Fill the contour + + cv2.drawContours(blank_image, co_par, -1, boundary_color, thickness=1) # Draw the boundary + + if len(co_drop)>0: + cv2.drawContours(blank_image, co_drop, -1, col_drop, thickness=cv2.FILLED) # Fill the contour + + if len(co_marginal)>0: + cv2.drawContours(blank_image, co_marginal, -1, col_marginal, thickness=cv2.FILLED) # Fill the contour + + if len(co_table)>0: + cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour + + img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) + + added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0) + return added_image + + +def visualize_image_from_contours(contours, img): + alpha = 0.5 + + blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 + + boundary_color = (0, 0, 255) # Dark gray for the boundary + fill_color = (173, 216, 230) # Lighter gray for the filled area + + cv2.drawContours(blank_image, contours, -1, fill_color, thickness=cv2.FILLED) # Fill the contour + cv2.drawContours(blank_image, contours, -1, boundary_color, thickness=1) # Draw the boundary + + img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) + + added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0) + return added_image + +def visualize_model_output(prediction, img, task): + if task == "binarization": + prediction = prediction * -1 + prediction = prediction + 1 + added_image = prediction * 255 + layout_only = None + else: + unique_classes = np.unique(prediction[:,:,0]) + rgb_colors = {'0' : [255, 255, 255], + '1' : [255, 0, 0], + '2' : [255, 125, 0], + '3' : [255, 0, 125], + '4' : [125, 125, 125], + '5' : [125, 125, 0], + '6' : [0, 125, 255], + '7' : [0, 125, 0], + '8' : [125, 125, 125], + '9' : [0, 125, 255], + '10' : [125, 0, 125], + '11' : [0, 255, 0], + '12' : [0, 0, 255], + '13' : [0, 255, 255], + '14' : [255, 125, 125], + '15' : [255, 0, 255]} + + layout_only = np.zeros(prediction.shape) + + for unq_class in unique_classes: + rgb_class_unique = rgb_colors[str(int(unq_class))] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + + + + img = resize_image(img, layout_only.shape[0], layout_only.shape[1]) + + layout_only = layout_only.astype(np.int32) + img = img.astype(np.int32) + + + + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) + + return added_image, layout_only + +def get_content_of_dir(dir_in): + """ + Listing all ground truth 
page xml files. All files are needed to have xml format. + """ + + gt_all=os.listdir(dir_in) + gt_list = [file for file in gt_all if os.path.splitext(file)[1] == '.xml'] + return gt_list + +def return_parent_contours(contours, hierarchy): + contours_parent = [contours[i] for i in range(len(contours)) if hierarchy[0][i][3] == -1] + return contours_parent +def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, min_area): + found_polygons_early = list() + + jv = 0 + for c in contours: + if len(np.shape(c)) == 3: + c = c[0] + elif len(np.shape(c)) == 2: + pass + #c = c[0] + if len(c) < 3: # A polygon cannot have less than 3 points + continue + + c_e = [point for point in c] + polygon = geometry.Polygon(c_e) + # area = cv2.contourArea(c) + area = polygon.area + # Check that polygon has area greater than minimal area + if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) + jv += 1 + return found_polygons_early + +def filter_contours_area_of_image(image, contours, order_index, max_area, min_area, min_early=None): + found_polygons_early = list() + order_index_filtered = list() + regions_ar_less_than_early_min = list() + #jv = 0 + for jv, c in enumerate(contours): + if len(np.shape(c)) == 3: + c = c[0] + elif len(np.shape(c)) == 2: + pass + if len(c) < 3: # A polygon cannot have less than 3 points + continue + c_e = [point for point in c] + polygon = geometry.Polygon(c_e) + area = polygon.area + if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint)) + order_index_filtered.append(order_index[jv]) + if min_early: + if area < min_early * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + regions_ar_less_than_early_min.append(1) + else: + regions_ar_less_than_early_min.append(0) + else: + regions_ar_less_than_early_min = None + + #jv += 1 + return found_polygons_early, order_index_filtered, regions_ar_less_than_early_min + +def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): + + # pixels of images are identified by 5 + if len(region_pre_p.shape) == 3: + cnts_images = (region_pre_p[:, :, 0] == pixel) * 1 + else: + cnts_images = (region_pre_p[:, :] == pixel) * 1 + cnts_images = cnts_images.astype(np.uint8) + cnts_images = np.repeat(cnts_images[:, :, np.newaxis], 3, axis=2) + imgray = cv2.cvtColor(cnts_images, cv2.COLOR_BGR2GRAY) + ret, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours_imgs, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + #print(len(contours_imgs), hierarchy) + + contours_imgs = return_parent_contours(contours_imgs, hierarchy) + + #print(len(contours_imgs), "iki") + #contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) + + return contours_imgs +def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=None, erosion_early=None): + co_text_eroded = [] + for con in co_text: + img_boundary_in = np.zeros( (y_len,x_len) ) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + + if dilation_early: + img_boundary_in = cv2.dilate(img_boundary_in[:,:], KERNEL, 
iterations=dilation_early) + + if erosion_early: + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_early) + + #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + if erosion_rate > 0: + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_rate) + + pixel = 1 + min_size = 0 + + img_boundary_in = img_boundary_in.astype("uint8") + + con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) + + try: + co_text_eroded.append(con_eroded[0]) + except: + co_text_eroded.append(con) + + + img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) + + boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + return co_text_eroded, img_boundary + +def get_textline_contours_for_visualization(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + tag_endings = ['}TextLine','}textline'] + co_use_case = [] + + for tag in region_tags: + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + return co_use_case, y_len, x_len + + +def get_textline_contours_and_ocr_text(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + tag_endings = ['}TextLine','}textline'] + co_use_case = [] + ocr_textlines = [] + + for tag in region_tags: + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + ocr_text_in = [''] + sumi = 0 + for vv in nn.iter(): + if vv.tag == link + 'Coords': + for childtest2 in nn: + if childtest2.tag.endswith("TextEquiv"): + for child_uc in childtest2: + if child_uc.tag.endswith("Unicode"): + text = child_uc.text + ocr_text_in[0]= text + + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + + + co_use_case.append(np.array(c_t_in)) + ocr_textlines.append(ocr_text_in[0]) + return co_use_case, y_len, x_len, ocr_textlines + +def fit_text_single_line(draw, text, font_path, max_width, max_height): + 
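+    """Return the largest PIL font (starting at size 50, floor 10) with which
+    `text` fits into a (max_width, max_height) box; used when rendering OCR
+    strings into textline bounding boxes."""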
initial_font_size = 50 + font_size = initial_font_size + while font_size > 10: # Minimum font size + font = ImageFont.truetype(font_path, font_size) + text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + if text_width <= max_width and text_height <= max_height: + return font # Return the best-fitting font + + font_size -= 2 # Reduce font size and retry + + return ImageFont.truetype(font_path, 10) # Smallest font fallback + +def get_layout_contours_for_visualization(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + all_defined_textregion_types = list(co_text.keys()) + co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + all_defined_graphic_types = list(co_graphic.keys()) + co_sep=[] + co_img=[] + co_table=[] + co_noise=[] + + types_text = [] + types_graphic = [] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + c_t_in = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , 
int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_text in list(c_t_in.keys()): + if len(c_t_in[element_text])>0: + co_text[element_text].append(np.array(c_t_in[element_text])) + + + if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_graphic in list(c_t_in_graphic.keys()): + if len(c_t_in_graphic[element_graphic])>0: + co_graphic[element_graphic].append(np.array(c_t_in_graphic[element_graphic])) + + + if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + if vv.tag==link+'Coords': + 
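+                        # PAGE-XML gives coordinates either as a 'points' attribute on the
+                        # Coords element or as a list of child Point elements; both are handled.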
coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + + if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + + if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len + +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): + """ + Reading the page xml files and write the ground truth images into given output directory. 
+ """ + ## to do: add footnote to text regions + + if dir_images: + ls_org_imgs = os.listdir(dir_images) + ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs] + for index in tqdm(range(len(gt_list))): + #try: + print(gt_list[index]) + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + comment_is_sub_element = False + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) + comment_is_sub_element = True + if not comment_is_sub_element: + num_col = None + + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) + + if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] + + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) + + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + img_poly = img_poly.astype(np.uint8) + + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + + cnt = contours[np.argmax(cnt_size)] + + x, y, w, h = cv2.boundingRect(cnt) + bb_xywh = [x, y, w, h] + + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): + keys = list(config_params.keys()) + if "artificial_class_label" in keys: + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + + textline_rgb_color = (255, 0, 0) + + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + + co_use_case = [] + + for tag in region_tags: + if config_params['use_case']=='textline': + tag_endings = 
['}TextLine','}textline']
+                    elif config_params['use_case']=='word':
+                        tag_endings = ['}Word','}word']
+                    elif config_params['use_case']=='glyph':
+                        tag_endings = ['}Glyph','}glyph']
+                    elif config_params['use_case']=='printspace':
+                        tag_endings = ['}PrintSpace','}printspace']
+
+                    if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
+                        for nn in root1.iter(tag):
+                            c_t_in = []
+                            sumi = 0
+                            for vv in nn.iter():
+                                # coordinates come either as a 'points' attribute or as child Point elements
+                                if vv.tag == link + 'Coords':
+                                    coords = bool(vv.attrib)
+                                    if coords:
+                                        p_h = vv.attrib['points'].split(' ')
+                                        c_t_in.append(
+                                            np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
+                                        break
+
+                                if vv.tag == link + 'Point':
+                                    c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
+                                    sumi += 1
+                                elif vv.tag != link + 'Point' and sumi >= 1:
+                                    break
+                            co_use_case.append(np.array(c_t_in))
+
+                if "artificial_class_label" in keys:
+                    img_boundary = np.zeros((y_len, x_len))
+                    erosion_rate = 0
+                    dilation_rate = 2
+                    dilation_early = 0
+                    erosion_early = 2
+                    co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
+
+                img = np.zeros((y_len, x_len, 3))
+                if output_type == '2d':
+                    img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
+                    if "artificial_class_label" in keys:
+                        img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
+                elif output_type == '3d':
+                    img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
+                    if "artificial_class_label" in keys:
+                        # boundary pixels become the artificial class, but only outside filled regions
+                        img_mask = np.copy(img_poly)
+                        img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
+                        img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
+                        img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
+
+                if printspace and config_params['use_case']!='printspace':
+                    img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
+
+                if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
+                    img_poly = resize_image(img_poly, y_new, x_new)
+
+                xml_file_stem = os.path.splitext(gt_list[index])[0]
+                cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
+
+                if dir_images:
+                    org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
+                    img_org = cv2.imread(os.path.join(dir_images, org_image_name))
+
+                    if printspace and config_params['use_case']!='printspace':
+                        img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
+
+                    if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
+                        img_org = resize_image(img_org, y_new, x_new)
+
+                    cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
+
+            if config_file and config_params['use_case']=='layout':
+                keys = list(config_params.keys())
+
+                if "artificial_class_on_boundary" in keys:
+                    elements_with_artificial_class = list(config_params['artificial_class_on_boundary'])
+                    artificial_class_rgb_color = (255,255,0)
+                    artificial_class_label = config_params['artificial_class_label']
+                    #values = 
config_params.values() + + if "printspace_as_class_in_layout" in list(config_params.keys()): + printspace_class_rgb_color = (125,125,255) + printspace_class_label = config_params['printspace_as_class_in_layout'] + + if 'textregions' in keys: + types_text_dict = config_params['textregions'] + types_text = list(types_text_dict.keys()) + types_text_label = list(types_text_dict.values()) + if 'graphicregions' in keys: + types_graphic_dict = config_params['graphicregions'] + types_graphic = list(types_graphic_dict.keys()) + types_graphic_label = list(types_graphic_dict.values()) + + + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] + + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + all_defined_textregion_types = list(co_text.keys()) + co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + all_defined_graphic_types = list(co_graphic.keys()) + co_sep=[] + co_img=[] + co_table=[] + co_noise=[] + + for tag in region_tags: + if 'textregions' in keys: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + c_t_in = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in 
nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_text in list(c_t_in.keys()): + if len(c_t_in[element_text])>0: + co_text[element_text].append(np.array(c_t_in[element_text])) + + if 'graphicregions' in keys: + if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_graphic in list(c_t_in_graphic.keys()): + if len(c_t_in_graphic[element_graphic])>0: + co_graphic[element_graphic].append(np.array(c_t_in_graphic[element_graphic])) + + + if 'imageregion' in keys: + if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + 
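# store the finished image-region polygon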
co_img.append(np.array(c_t_in)) + + + if 'separatorregion' in keys: + if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + if 'tableregion' in keys: + if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + if 'noiseregion' in keys: + if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + + if "artificial_class_on_boundary" in keys: + img_boundary = np.zeros( (y_len,x_len) ) + if "paragraph" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text['paragraph'], img_boundary = update_region_contours(co_text['paragraph'], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "drop-capital" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 3 + co_text["drop-capital"], img_boundary = update_region_contours(co_text["drop-capital"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "catch-word" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 3#4 + co_text["catch-word"], img_boundary = update_region_contours(co_text["catch-word"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "page-number" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 3#4 + co_text["page-number"], img_boundary = update_region_contours(co_text["page-number"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "header" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text["header"], img_boundary = update_region_contours(co_text["header"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "heading" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text["heading"], img_boundary = update_region_contours(co_text["heading"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "signature-mark" in elements_with_artificial_class: + 
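# signature marks use the same boundary widths as headings and headers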
erosion_rate = 1
+                    dilation_rate = 4
+                    co_text["signature-mark"], img_boundary = update_region_contours(co_text["signature-mark"], img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+                if "marginalia" in elements_with_artificial_class:
+                    erosion_rate = 2
+                    dilation_rate = 4
+                    co_text["marginalia"], img_boundary = update_region_contours(co_text["marginalia"], img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+                if "footnote" in elements_with_artificial_class:
+                    erosion_rate = 0
+                    dilation_rate = 2
+                    co_text["footnote"], img_boundary = update_region_contours(co_text["footnote"], img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+                if "footnote-continued" in elements_with_artificial_class:
+                    erosion_rate = 0
+                    dilation_rate = 2
+                    co_text["footnote-continued"], img_boundary = update_region_contours(co_text["footnote-continued"], img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+                if "tableregion" in elements_with_artificial_class:
+                    erosion_rate = 0
+                    dilation_rate = 3
+                    co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+
+            img = np.zeros( (y_len,x_len,3) )
+
+            if output_type == '3d':
+                if 'graphicregions' in keys:
+                    if 'rest_as_decoration' in types_graphic:
+                        # map the 'rest_as_decoration' placeholder onto the decoration class
+                        types_graphic = ['decoration' if ttind=='rest_as_decoration' else ttind for ttind in types_graphic]
+                        for element_graphic in types_graphic:
+                            if element_graphic == 'decoration':
+                                color_label = labels_rgb_color[ config_params['graphicregions']['rest_as_decoration']]
+                            else:
+                                color_label = labels_rgb_color[ config_params['graphicregions'][element_graphic]]
+                            img_poly = cv2.fillPoly(img, pts=co_graphic[element_graphic], color=color_label)
+                    else:
+                        for element_graphic in types_graphic:
+                            color_label = labels_rgb_color[ config_params['graphicregions'][element_graphic]]
+                            img_poly = cv2.fillPoly(img, pts=co_graphic[element_graphic], color=color_label)
+
+                if 'imageregion' in keys:
+                    img_poly = cv2.fillPoly(img, pts=co_img, color=labels_rgb_color[ config_params['imageregion']])
+                if 'tableregion' in keys:
+                    img_poly = cv2.fillPoly(img, pts=co_table, color=labels_rgb_color[ config_params['tableregion']])
+                if 'noiseregion' in keys:
+                    img_poly = cv2.fillPoly(img, pts=co_noise, color=labels_rgb_color[ config_params['noiseregion']])
+
+                if 'textregions' in keys:
+                    if 'rest_as_paragraph' in types_text:
+                        types_text = ['paragraph' if ttind=='rest_as_paragraph' else ttind for ttind in types_text]
+                        for element_text in types_text:
+                            if element_text == 'paragraph':
+                                color_label = labels_rgb_color[ config_params['textregions']['rest_as_paragraph']]
+                            else:
+                                color_label = labels_rgb_color[ config_params['textregions'][element_text]]
+                            img_poly = cv2.fillPoly(img, pts=co_text[element_text], color=color_label)
+                    else:
+                        for element_text in types_text:
+                            color_label = labels_rgb_color[ config_params['textregions'][element_text]]
+                            img_poly = cv2.fillPoly(img, pts=co_text[element_text], color=color_label)
+
+                if "artificial_class_on_boundary" in keys:
+                    img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0]
+                    img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1]
+                    img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2]
+
+                if 'separatorregion' in keys:
+                    img_poly = cv2.fillPoly(img, pts=co_sep, color=labels_rgb_color[ config_params['separatorregion']])
+
+                if "printspace_as_class_in_layout" in list(config_params.keys()):
+                    # everything outside the print space bounding box is labeled as the printspace class
+                    printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
+                    printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
+
+                    img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0]
+                    img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1]
+                    img_poly[:,:,2][printspace_mask[:,:] == 0] = printspace_class_rgb_color[2]
+
+            elif output_type == '2d':
+                if 'graphicregions' in keys:
+                    if 'rest_as_decoration' in types_graphic:
+                        types_graphic = ['decoration' if ttind=='rest_as_decoration' else ttind for ttind in types_graphic]
+                        for element_graphic in types_graphic:
+                            if element_graphic == 'decoration':
+                                color_label = config_params['graphicregions']['rest_as_decoration']
+                            else:
+                                color_label = config_params['graphicregions'][element_graphic]
+                            img_poly = cv2.fillPoly(img, pts=co_graphic[element_graphic], color=color_label)
+                    else:
+                        for element_graphic in types_graphic:
+                            color_label = config_params['graphicregions'][element_graphic]
+                            img_poly = cv2.fillPoly(img, pts=co_graphic[element_graphic], color=color_label)
+
+                if 'imageregion' in keys:
+                    color_label = config_params['imageregion']
+                    img_poly = cv2.fillPoly(img, pts=co_img, color=(color_label,color_label,color_label))
+                if 'tableregion' in keys:
+                    color_label = config_params['tableregion']
+                    img_poly = cv2.fillPoly(img, pts=co_table, color=(color_label,color_label,color_label))
+                if 'noiseregion' in keys:
+                    color_label = config_params['noiseregion']
+                    img_poly = cv2.fillPoly(img, pts=co_noise, color=(color_label,color_label,color_label))
+
+                if 'textregions' in keys:
+                    if 'rest_as_paragraph' in types_text:
+                        types_text = ['paragraph' if ttind=='rest_as_paragraph' else ttind for ttind in types_text]
+                        for element_text in types_text:
+                            if element_text == 'paragraph':
+                                color_label = config_params['textregions']['rest_as_paragraph']
+                            else:
+                                color_label = config_params['textregions'][element_text]
+                            img_poly = cv2.fillPoly(img, pts=co_text[element_text], color=color_label)
+                    else:
+                        for element_text in types_text:
+                            color_label = config_params['textregions'][element_text]
+                            img_poly = cv2.fillPoly(img, pts=co_text[element_text], color=color_label)
+
+                if "artificial_class_on_boundary" in keys:
+                    img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
+
+                if 'separatorregion' in keys:
+                    color_label = config_params['separatorregion']
+                    img_poly = cv2.fillPoly(img, pts=co_sep, color=(color_label,color_label,color_label))
+
+                if "printspace_as_class_in_layout" in list(config_params.keys()):
+                    printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
+                    printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
+
+                    img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label
+                    img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label
+                    img_poly[:,:,2][printspace_mask[:,:] == 0] = printspace_class_label
+
+            if printspace:
+                img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
+
+            if 'columns_width' in list(config_params.keys()) and num_col:
+                img_poly = resize_image(img_poly, y_new, x_new)
+
+            xml_file_stem = os.path.splitext(gt_list[index])[0]
+            cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
+
+            if dir_images:
+                org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
+                img_org = cv2.imread(os.path.join(dir_images, org_image_name))
+
+                if printspace:
+                    img_org = 
img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col: + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) + + + +def find_new_features_of_contours(contours_main): + + areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))]) + M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))] + cx_main = [(M_main[j]["m10"] / (M_main[j]["m00"] + 1e-32)) for j in range(len(M_main))] + cy_main = [(M_main[j]["m01"] / (M_main[j]["m00"] + 1e-32)) for j in range(len(M_main))] + try: + x_min_main = np.array([np.min(contours_main[j][0][:, 0]) for j in range(len(contours_main))]) + + argmin_x_main = np.array([np.argmin(contours_main[j][0][:, 0]) for j in range(len(contours_main))]) + + x_min_from_argmin = np.array([contours_main[j][0][argmin_x_main[j], 0] for j in range(len(contours_main))]) + y_corr_x_min_from_argmin = np.array([contours_main[j][0][argmin_x_main[j], 1] for j in range(len(contours_main))]) + + x_max_main = np.array([np.max(contours_main[j][0][:, 0]) for j in range(len(contours_main))]) + + y_min_main = np.array([np.min(contours_main[j][0][:, 1]) for j in range(len(contours_main))]) + y_max_main = np.array([np.max(contours_main[j][0][:, 1]) for j in range(len(contours_main))]) + except: + x_min_main = np.array([np.min(contours_main[j][:, 0]) for j in range(len(contours_main))]) + + argmin_x_main = np.array([np.argmin(contours_main[j][:, 0]) for j in range(len(contours_main))]) + + x_min_from_argmin = np.array([contours_main[j][argmin_x_main[j], 0] for j in range(len(contours_main))]) + y_corr_x_min_from_argmin = np.array([contours_main[j][argmin_x_main[j], 1] for j in range(len(contours_main))]) + + x_max_main = np.array([np.max(contours_main[j][:, 0]) for j in range(len(contours_main))]) + + y_min_main = np.array([np.min(contours_main[j][:, 1]) for j in range(len(contours_main))]) + y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))]) + + return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin +def read_xml(xml_file): + file_name = Path(xml_file).stem + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + index_tot_regions = [] + tot_region_ref = [] + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + for jj in root1.iter(link+'RegionRefIndexed'): + index_tot_regions.append(jj.attrib['index']) + tot_region_ref.append(jj.attrib['regionRef']) + + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): + co_printspace = [] + if link+'PrintSpace' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + elif link+'Border' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) + + for tag in region_tags_printspace: + if link+'PrintSpace' in alltags: + tag_endings_printspace = ['}PrintSpace','}printspace'] + elif link+'Border' in alltags: + tag_endings_printspace = ['}Border','}border'] + + if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = 
bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_printspace.append(np.array(c_t_in)) + img_printspace = np.zeros( (y_len,x_len,3) ) + img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1)) + img_printspace = img_printspace.astype(np.uint8) + + imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + + bb_coord_printspace = [x, y, w, h] + + else: + bb_coord_printspace = None + + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + co_text_paragraph=[] + co_text_drop=[] + co_text_heading=[] + co_text_header=[] + co_text_marginalia=[] + co_text_catch=[] + co_text_page_number=[] + co_text_signature_mark=[] + co_sep=[] + co_img=[] + co_table=[] + co_graphic=[] + co_graphic_text_annotation=[] + co_graphic_decoration=[] + co_noise=[] + + co_text_paragraph_text=[] + co_text_drop_text=[] + co_text_heading_text=[] + co_text_header_text=[] + co_text_marginalia_text=[] + co_text_catch_text=[] + co_text_page_number_text=[] + co_text_signature_mark_text=[] + co_sep_text=[] + co_img_text=[] + co_table_text=[] + co_graphic_text=[] + co_graphic_text_annotation_text=[] + co_graphic_decoration_text=[] + co_noise_text=[] + + id_paragraph = [] + id_header = [] + id_heading = [] + id_marginalia = [] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + for child2 in nn: + tag2 = child2.tag + if tag2.endswith('}TextEquiv') or tag2.endswith('}TextEquiv'): + for childtext2 in child2: + if childtext2.tag.endswith('}Unicode') or childtext2.tag.endswith('}Unicode'): + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + co_text_drop_text.append(childtext2.text) + elif "type" in nn.attrib and nn.attrib['type']=='heading': + co_text_heading_text.append(childtext2.text) + elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': + co_text_signature_mark_text.append(childtext2.text) + elif "type" in nn.attrib and nn.attrib['type']=='header': + co_text_header_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###co_text_catch_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': + ###co_text_page_number_text.append(childtext2.text) + elif "type" in nn.attrib and nn.attrib['type']=='marginalia': + co_text_marginalia_text.append(childtext2.text) + else: + co_text_paragraph_text.append(childtext2.text) + c_t_in_drop=[] + c_t_in_paragraph=[] + c_t_in_heading=[] + c_t_in_header=[] + c_t_in_page_number=[] + c_t_in_signature_mark=[] + c_t_in_catch=[] + c_t_in_marginalia=[] + + + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + #print('birda1') + p_h=vv.attrib['points'].split(' ') + + + + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + + c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in 
p_h] ) ) + + elif "type" in nn.attrib and nn.attrib['type']=='heading': + ##id_heading.append(nn.attrib['id']) + c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': + + c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #print(c_t_in_paragraph) + elif "type" in nn.attrib and nn.attrib['type']=='header': + #id_header.append(nn.attrib['id']) + c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': + + ###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + elif "type" in nn.attrib and nn.attrib['type']=='marginalia': + #id_marginalia.append(nn.attrib['id']) + + c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + #id_paragraph.append(nn.attrib['id']) + + c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + + c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='heading': + #id_heading.append(nn.attrib['id']) + c_t_in_heading.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + + elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': + + c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + elif "type" in nn.attrib and nn.attrib['type']=='header': + #id_header.append(nn.attrib['id']) + c_t_in_header.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 + + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': + + ###c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='marginalia': + #id_marginalia.append(nn.attrib['id']) + + c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + else: + #id_paragraph.append(nn.attrib['id']) + c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + + if len(c_t_in_drop)>0: + co_text_drop.append(np.array(c_t_in_drop)) + if len(c_t_in_paragraph)>0: + co_text_paragraph.append(np.array(c_t_in_paragraph)) + id_paragraph.append(nn.attrib['id']) + if len(c_t_in_heading)>0: + co_text_heading.append(np.array(c_t_in_heading)) + id_heading.append(nn.attrib['id']) + + if len(c_t_in_header)>0: + co_text_header.append(np.array(c_t_in_header)) + id_header.append(nn.attrib['id']) + if len(c_t_in_page_number)>0: + co_text_page_number.append(np.array(c_t_in_page_number)) + if len(c_t_in_catch)>0: + co_text_catch.append(np.array(c_t_in_catch)) + + if len(c_t_in_signature_mark)>0: + 
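# only keep region types for which at least one polygon was actually found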
co_text_signature_mark.append(np.array(c_t_in_signature_mark)) + + if len(c_t_in_marginalia)>0: + co_text_marginalia.append(np.array(c_t_in_marginalia)) + id_marginalia.append(nn.attrib['id']) + + + elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + for nn in root1.iter(tag): + c_t_in=[] + c_t_in_text_annotation=[] + c_t_in_decoration=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + elif "type" in nn.attrib and nn.attrib['type']=='decoration': + c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + break + else: + pass + + + if vv.tag==link+'Point': + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='decoration': + c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + else: + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + if len(c_t_in_text_annotation)>0: + co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) + if len(c_t_in_decoration)>0: + co_graphic_decoration.append(np.array(c_t_in_decoration)) + if len(c_t_in)>0: + co_graphic.append(np.array(c_t_in)) + + + + elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + co_img_text.append(' ') + + + elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + 
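# table regions carry no transcription; a blank placeholder keeps the text lists aligned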
co_table.append(np.array(c_t_in)) + co_table_text.append(' ') + + elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + co_noise_text.append(' ') + + img = np.zeros( (y_len,x_len,3) ) + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1)) + + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2)) + img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2)) + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(3,3,3)) + img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) + img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) + + return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ +tot_region_ref,x_len, y_len,index_tot_regions, img_poly + + + + +def bounding_box(cnt,color, corr_order_index ): + x, y, w, h = cv2.boundingRect(cnt) + x = int(x*scale_w) + y = int(y*scale_h) + + w = int(w*scale_w) + h = int(h*scale_h) + + return [x,y,w,h,int(color), int(corr_order_index)+1] + +def resize_image(seg_in,input_height,input_width): + return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST) + +def make_image_from_bb(width_l, height_l, bb_all): + bb_all =np.array(bb_all) + img_remade = np.zeros((height_l,width_l )) + + for i in range(bb_all.shape[0]): + img_remade[bb_all[i,1]:bb_all[i,1]+bb_all[i,3],bb_all[i,0]:bb_all[i,0]+bb_all[i,2] ] = 1 + return img_remade + +def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be_updated, innner_index_pr_pos, pr_list, pos_list,list_inp): + list_inp.pop(index_element_to_be_updated) + if len(pr_list)>0: + list_inp.insert(index_element_to_be_updated, pr_list) + else: + index_element_to_be_updated = index_element_to_be_updated -1 + + list_inp.insert(index_element_to_be_updated+1, [innner_index_pr_pos]) + if len(pos_list)>0: + list_inp.insert(index_element_to_be_updated+2, pos_list) + + len_all_elements = [len(i) for i in list_inp] + list_len_bigger_1 = np.where(np.array(len_all_elements)>1) + list_len_bigger_1 = list_len_bigger_1[0] + + if len(list_len_bigger_1)>0: + early_list_bigger_than_one = list_len_bigger_1[0] + else: + early_list_bigger_than_one = -20 + return list_inp, early_list_bigger_than_one + +def overlay_layout_on_image(prediction, img, cx_ordered, cy_ordered, color, thickness): + + unique_classes = np.unique(prediction[:,:,0]) + rgb_colors = {'0' : [255, 255, 255], + '1' : [255, 0, 0], + '2' : [0, 0, 255], + '3' : [255, 0, 125], + '4' : [125, 125, 125], + '5' : [125, 125, 0], + '6' : [0, 125, 255], + '7' : [0, 125, 0], + '8' : [125, 125, 125], + '9' : [0, 125, 255], + '10' : [125, 0, 125], + '11' : [0, 255, 0], + '12' : [255, 125, 0], + '13' : [0, 255, 255], + '14' : [255, 125, 125], + '15' : [255, 0, 255]} + + layout_only = np.zeros(prediction.shape) + + for unq_class in unique_classes: + rgb_class_unique = rgb_colors[str(int(unq_class))] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + 
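# all three channels get the RGB color assigned to this class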
layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + + + + #img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1]) + + layout_only = layout_only.astype(np.int32) + + for i in range(len(cx_ordered)-1): + start_point = (int(cx_ordered[i]), int(cy_ordered[i])) + end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) + layout_only = cv2.arrowedLine(layout_only, start_point, end_point, + color, thickness, tipLength = 0.03) + + img = img.astype(np.int32) + + + + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) + + return added_image + +def find_format_of_given_filename_in_dir(dir_imgs, f_name): + ls_imgs = os.listdir(dir_imgs) + file_interested = [ind for ind in ls_imgs if ind.startswith(f_name+'.')] + return file_interested[0] diff --git a/train/inference.py b/train/inference.py new file mode 100644 index 0000000..094c528 --- /dev/null +++ b/train/inference.py @@ -0,0 +1,683 @@ +import sys +import os +import numpy as np +import warnings +import cv2 +import seaborn as sns +from tensorflow.keras.models import load_model +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow.keras import layers +import tensorflow.keras.losses +from tensorflow.keras.layers import * +from models import * +from gt_gen_utils import * +import click +import json +from tensorflow.python.keras import backend as tensorflow_backend +import xml.etree.ElementTree as ET +import matplotlib.pyplot as plt + + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + +__doc__=\ +""" +Tool to load model and predict for given image. +""" + +class sbb_predict: + def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + self.image=image + self.dir_in=dir_in + self.patches=patches + self.save=save + self.save_layout=save_layout + self.model_dir=model + self.ground_truth=ground_truth + self.task=task + self.config_params_model=config_params_model + self.xml_file = xml_file + self.out = out + if min_area: + self.min_area = float(min_area) + else: + self.min_area = 0 + + def resize_image(self,img_in,input_height,input_width): + return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) + + + def color_images(self,seg): + ann_u=range(self.n_classes) + if len(np.shape(seg))==3: + seg=seg[:,:,0] + + seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8) + colors=sns.color_palette("hls", self.n_classes) + + for c in ann_u: + c=int(c) + segl=(seg==c) + seg_img[:,:,0][seg==c]=c + seg_img[:,:,1][seg==c]=c + seg_img[:,:,2][seg==c]=c + return seg_img + + def otsu_copy_binary(self,img): + img_r=np.zeros((img.shape[0],img.shape[1],3)) + img1=img[:,:,0] + + #print(img.min()) + #print(img[:,:,0].min()) + #blur = cv2.GaussianBlur(img,(5,5)) + #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) + retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + + + + img_r[:,:,0]=threshold1 + img_r[:,:,1]=threshold1 + img_r[:,:,2]=threshold1 + #img_r=img_r/float(np.max(img_r))*255 + return img_r + + def otsu_copy(self,img): + img_r=np.zeros((img.shape[0],img.shape[1],3)) + #img1=img[:,:,0] + + #print(img.min()) + #print(img[:,:,0].min()) + #blur = cv2.GaussianBlur(img,(5,5)) + #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold1 = cv2.threshold(img[:,:,0], 0, 255, 
cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+        _, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+        _, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+
+        img_r[:,:,0] = threshold1
+        img_r[:,:,1] = threshold2
+        img_r[:,:,2] = threshold3
+        return img_r
+
+    def soft_dice_loss(self, y_true, y_pred, epsilon=1e-6):
+        # sum over the spatial axes, keeping batch and class axes
+        axes = tuple(range(1, len(y_pred.shape)-1))
+        numerator = 2. * K.sum(y_pred * y_true, axes)
+        denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
+        return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
+
+    def weighted_categorical_crossentropy(self, weights=None):
+
+        def loss(y_true, y_pred):
+            labels_floats = tf.cast(y_true, tf.float32)
+            per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)
+
+            if weights is not None:
+                weight_mask = tf.maximum(tf.reduce_max(tf.constant(
+                    np.array(weights, dtype=np.float32)[None, None, None])
+                    * labels_floats, axis=-1), 1.0)
+                per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
+            return tf.reduce_mean(per_pixel_loss)
+        return loss
+
+    def IoU(self, Yi, y_predi):
+        ## mean Intersection over Union
+        ## Mean IoU = TP/(FN + TP + FP)
+
+        IoUs = []
+        Nclass = np.unique(Yi)
+        for c in Nclass:
+            TP = np.sum( (Yi == c)&(y_predi == c) )
+            FP = np.sum( (Yi != c)&(y_predi == c) )
+            FN = np.sum( (Yi == c)&(y_predi != c) )
+            IoU = TP/float(TP + FP + FN)
+            if self.n_classes > 2:
+                print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU))
+            IoUs.append(IoU)
+        if self.n_classes > 2:
+            mIoU = np.mean(IoUs)
+            print("_________________")
+            print("Mean IoU: {:4.3f}".format(mIoU))
+            return mIoU
+        elif self.n_classes == 2:
+            mIoU = IoUs[1]
+            print("_________________")
+            print("IoU: {:4.3f}".format(mIoU))
+            return mIoU
+
+    def start_new_session_and_model(self):
+        config = tf.compat.v1.ConfigProto()
+        config.gpu_options.allow_growth = True
+
+        session = tf.compat.v1.Session(config=config)
+        tensorflow_backend.set_session(session)
+        # the custom layers from models.py must be registered so the saved model can be deserialized
+        self.model = load_model(self.model_dir, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches})
+
+        if (self.task != 'classification' and self.task != 'reading_order'):
+            self.img_height = self.model.layers[len(self.model.layers)-1].output_shape[1]
+            self.img_width = self.model.layers[len(self.model.layers)-1].output_shape[2]
+            self.n_classes = self.model.layers[len(self.model.layers)-1].output_shape[3]
+
+    def visualize_model_output(self, prediction, img, task):
+        if task == "binarization":
+            prediction = prediction * -1
+            prediction = prediction + 1
+            added_image = prediction * 255
+            layout_only = None
+        else:
+            unique_classes = np.unique(prediction[:,:,0])
+            rgb_colors = {'0' : [255, 255, 255],
+                          '1' : [255, 0, 0],
+                          '2' : [255, 125, 0],
+                          '3' : [255, 0, 125],
+                          '4' : [125, 125, 125],
+                          '5' : [125, 125, 0],
+                          '6' : [0, 125, 255],
+                          '7' : [0, 125, 0],
+                          '8' : [125, 125, 125],
+                          '9' : [0, 125, 255],
+                          '10' : [125, 0, 125],
+                          '11' : [0, 255, 0],
+                          '12' : [0, 0, 255],
+                          '13' : [0, 255, 255],
+                          '14' : [255, 125, 125],
+                          '15' : [255, 0, 255]}
+
+            layout_only = np.zeros(prediction.shape)
+
+            for unq_class in unique_classes:
+                rgb_class_unique = rgb_colors[str(int(unq_class))]
+                layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
+                layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
+                layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
+
+            img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1])
+
+            layout_only = layout_only.astype(np.int32)
+            img = img.astype(np.int32)
+
+            added_image = cv2.addWeighted(img, 0.5, layout_only, 0.1, 0)
+
+        return added_image, layout_only
+
+    def predict(self, image_dir):
+        if self.task == 'classification':
+            classes_names = self.config_params_model['classification_classes_name']
+            img_1ch = cv2.imread(image_dir, 0)
+            img_1ch = img_1ch / 255.0
+            # cv2.resize expects the target size as (width, height)
+            img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_width'], self.config_params_model['input_height']), interpolation=cv2.INTER_NEAREST)
+            img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
+            img_in[0, :, :, 0] = img_1ch[:, :]
+            img_in[0, :, :, 1] = img_1ch[:, :]
+            img_in[0, :, :, 2] = img_1ch[:, :]
+
+            label_p_pred = self.model.predict(img_in, verbose=0)
+            index_class = np.argmax(label_p_pred[0])
+
+            print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
+        elif self.task == 'reading_order':
+            img_height = self.config_params_model['input_height']
+            img_width = self.config_params_model['input_width']
+
+            tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
+            _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
+
+            # mark a 12-pixel strip below each header region
+            img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
+            for j in range(len(cy_main)):
+                img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12, int(x_min_main[j]):int(x_max_main[j])] = 1
+
+            co_text_all = co_text_paragraph + co_text_header
+            id_all_text = id_paragraph + id_header
+
+            texts_corr_order_index_int = list(range(len(co_text_all)))
+
+            max_area = 1
+            co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area)
+
+            id_all_text = [id_all_text[index] for index in texts_corr_order_index_int]
+
+            # one binary mask per text region
+            labels_con = np.zeros((y_len,x_len,len(co_text_all)), dtype='uint8')
+            for i in range(len(co_text_all)):
+                img_label = np.zeros((y_len,x_len,3), dtype='uint8')
+                img_label = cv2.fillPoly(img_label, pts=[co_text_all[i]], color=(1,1,1))
+                labels_con[:,:,i] = img_label[:,:,0]
+
+            if bb_coord_printspace:
+                x, y, w, h = bb_coord_printspace
+                labels_con = labels_con[y:y+h, x:x+w, :]
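+                # img_poly and the header/separator map are cropped to the same print space window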
img_poly = img_poly[y:y+h, x:x+w, :] + img_header_and_sep = img_header_and_sep[y:y+h, x:x+w] + + + + img3= np.copy(img_poly) + labels_con = resize_image(labels_con, img_height, img_width) + + img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width) + + img3= resize_image (img3, img_height, img_width) + img3 = img3.astype(np.uint16) + + inference_bs = 1#4 + + input_1= np.zeros( (inference_bs, img_height, img_width,3)) + + + starting_list_of_regions = [] + starting_list_of_regions.append( list(range(labels_con.shape[2])) ) + + index_update = 0 + index_selected = starting_list_of_regions[0] + + scalibility_num = 0 + while index_update>=0: + ij_list = starting_list_of_regions[index_update] + i = ij_list[0] + ij_list.pop(0) + + + pr_list = [] + post_list = [] + + batch_counter = 0 + tot_counter = 1 + + tot_iteration = len(ij_list) + full_bs_ite= tot_iteration//inference_bs + last_bs = tot_iteration % inference_bs + + jbatch_indexer =[] + for j in ij_list: + img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2) + img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2) + + + img2[:,:,0][img3[:,:,0]==5] = 2 + img2[:,:,0][img_header_and_sep[:,:]==1] = 3 + + + + img1[:,:,0][img3[:,:,0]==5] = 2 + img1[:,:,0][img_header_and_sep[:,:]==1] = 3 + + #input_1= np.zeros( (height1, width1,3)) + + + jbatch_indexer.append(j) + + input_1[batch_counter,:,:,0] = img1[:,:,0]/3. + input_1[batch_counter,:,:,2] = img2[:,:,0]/3. + input_1[batch_counter,:,:,1] = img3[:,:,0]/5. + #input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3)) + batch_counter = batch_counter+1 + + #input_1[:,:,0] = img1[:,:,0]/3. + #input_1[:,:,2] = img2[:,:,0]/3. + #input_1[:,:,1] = img3[:,:,0]/5. + + if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs): + y_pr = self.model.predict(input_1 , verbose=0) + scalibility_num = scalibility_num+1 + + if batch_counter==inference_bs: + iteration_batches = inference_bs + else: + iteration_batches = last_bs + for jb in range(iteration_batches): + if y_pr[jb][0]>=0.5: + post_list.append(jbatch_indexer[jb]) + else: + pr_list.append(jbatch_indexer[jb]) + + batch_counter = 0 + jbatch_indexer = [] + + tot_counter = tot_counter+1 + + starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions) + + + index_sort = [i[0] for i in starting_list_of_regions ] + + id_all_text = np.array(id_all_text)[index_sort] + + alltags=[elem.tag for elem in root_xml.iter()] + + + + link=alltags[0].split('}')[0]+'}' + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + page_element = root_xml.find(link+'Page') + + """ + ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + #print(page_element, 'page_element') + + #new_element = ET.Element('ReadingOrder') + + new_element_element = ET.Element('OrderedGroup') + new_element_element.set('id', "ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.Element('RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index_sort[index])) + + new_element_element.append(new_element_2) + + ro_subelement.append(new_element_element) + """ + ##ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + + ro_subelement = ET.Element('ReadingOrder') + + ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup') + ro_subelement2.set('id', 
"ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index)) + + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): + page_element.insert(1, ro_subelement) + else: + page_element.insert(0, ro_subelement) + + alltags=[elem.tag for elem in root_xml.iter()] + + ET.register_namespace("",name_space) + tree_xml.write(os.path.join(self.out, file_name+'.xml'),xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + #tree_xml.write('library2.xml') + + else: + if self.patches: + #def textline_contours(img,input_width,input_height,n_classes,model): + + img=cv2.imread(image_dir) + self.img_org = np.copy(img) + + if img.shape[0] < self.img_height: + img = self.resize_image(img, self.img_height, img.shape[1]) + + if img.shape[1] < self.img_width: + img = self.resize_image(img, img.shape[0], self.img_width) + + margin = int(0.1 * self.img_width) + width_mid = self.img_width - 2 * margin + height_mid = self.img_height - 2 * margin + img = img / float(255.0) + + img_h = img.shape[0] + img_w = img.shape[1] + + prediction_true = np.zeros((img_h, img_w, 3)) + nxf = img_w / float(width_mid) + nyf = img_h / float(height_mid) + + nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf) + nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf) + + for i in range(nxf): + for j in range(nyf): + if i == 0: + index_x_d = i * width_mid + index_x_u = index_x_d + self.img_width + else: + index_x_d = i * width_mid + index_x_u = index_x_d + self.img_width + if j == 0: + index_y_d = j * height_mid + index_y_u = index_y_d + self.img_height + else: + index_y_d = j * height_mid + index_y_u = index_y_d + self.img_height + + if index_x_u > img_w: + index_x_u = img_w + index_x_d = img_w - self.img_width + if index_y_u > img_h: + index_y_u = img_h + index_y_d = img_h - self.img_height + + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] + label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]), + verbose=0) + + if self.task == 'enhancement': + seg = label_p_pred[0, :, :, :] + seg = seg * 255 + elif self.task == 'segmentation' or self.task == 'binarization': + seg = np.argmax(label_p_pred, axis=3)[0] + seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + + + if i == 0 and j == 0: + seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg + elif i == 0 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j == 0: + seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg + elif i == 0 and j != 0 and j != nyf - 1: + seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j != 0 and j != nyf - 1: + seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0] + 
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg + elif i != 0 and i != nxf - 1 and j == 0: + seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg + elif i != 0 and i != nxf - 1 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg + else: + seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg + prediction_true = prediction_true.astype(int) + prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST) + return prediction_true + + else: + + img=cv2.imread(image_dir) + self.img_org = np.copy(img) + + width=self.img_width + height=self.img_height + + img=img/255.0 + img=self.resize_image(img,self.img_height,self.img_width) + + + label_p_pred=self.model.predict( + img.reshape(1,img.shape[0],img.shape[1],img.shape[2])) + + if self.task == 'enhancement': + seg = label_p_pred[0, :, :, :] + seg = seg * 255 + elif self.task == 'segmentation' or self.task == 'binarization': + seg = np.argmax(label_p_pred, axis=3)[0] + seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + + prediction_true = seg.astype(int) + + prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST) + return prediction_true + + + + def run(self): + self.start_new_session_and_model() + if self.image: + res=self.predict(image_dir = self.image) + + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + if self.save: + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + if self.save: + cv2.imwrite(self.save,img_seg_overlayed) + if self.save_layout: + cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + + else: + ls_images = os.listdir(self.dir_in) + for ind_image in ls_images: + f_name = ind_image.split('.')[0] + image_dir = os.path.join(self.dir_in, ind_image) + res=self.predict(image_dir) + + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + self.save = os.path.join(self.out, f_name+'.png') + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + self.save = os.path.join(self.out, f_name+'_overlayed.png') + cv2.imwrite(self.save,img_seg_overlayed) + self.save_layout = os.path.join(self.out, f_name+'_layout.png') + cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + + + +@click.command() +@click.option( + "--image", + "-i", + help="image filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_in", + "-di", + help="directory of images", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--out", + "-o", + help="output directory where xml with detected reading order will be written.", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + 
"--patches/--no-patches", + "-p/-nop", + is_flag=True, + help="if this parameter set to true, this tool will try to do inference in patches.", +) +@click.option( + "--save", + "-s", + help="save prediction as a png file in current folder.", +) +@click.option( + "--save_layout", + "-sl", + help="save layout prediction only as a png file in current folder.", +) +@click.option( + "--model", + "-m", + help="directory of models", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--ground_truth", + "-gt", + help="ground truth directory if you want to see the iou of prediction.", +) +@click.option( + "--xml_file", + "-xml", + help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", +) + +@click.option( + "--min_area", + "-min", + help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", +) +def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + assert image or dir_in, "Either a single image -i or a dir_in -di is required" + with open(os.path.join(model,'config.json')) as f: + config_params_model = json.load(f) + task = config_params_model['task'] + if (task != 'classification' and task != 'reading_order'): + if image and not save: + print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") + sys.exit(1) + if dir_in and not out: + print("Error: You used one of segmentation or binarization task with dir_in but not set -out") + sys.exit(1) + x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) + x.run() + +if __name__=="__main__": + main() + + + + diff --git a/train/metrics.py b/train/metrics.py new file mode 100644 index 0000000..cd30b02 --- /dev/null +++ b/train/metrics.py @@ -0,0 +1,357 @@ +from tensorflow.keras import backend as K +import tensorflow as tf +import numpy as np + + +def focal_loss(gamma=2., alpha=4.): + gamma = float(gamma) + alpha = float(alpha) + + def focal_loss_fixed(y_true, y_pred): + """Focal loss for multi-classification + FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t) + Notice: y_pred is probability after softmax + gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper + d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x) + Focal Loss for Dense Object Detection + https://arxiv.org/abs/1708.02002 + + Arguments: + y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls] + y_pred {tensor} -- model's output, shape of [batch_size, num_cls] + + Keyword Arguments: + gamma {float} -- (default: {2.0}) + alpha {float} -- (default: {4.0}) + + Returns: + [tensor] -- loss. 
+ """ + epsilon = 1.e-9 + y_true = tf.convert_to_tensor(y_true, tf.float32) + y_pred = tf.convert_to_tensor(y_pred, tf.float32) + + model_out = tf.add(y_pred, epsilon) + ce = tf.multiply(y_true, -tf.log(model_out)) + weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma)) + fl = tf.multiply(alpha, tf.multiply(weight, ce)) + reduced_fl = tf.reduce_max(fl, axis=1) + return tf.reduce_mean(reduced_fl) + + return focal_loss_fixed + + +def weighted_categorical_crossentropy(weights=None): + """ weighted_categorical_crossentropy + + Args: + * weights: crossentropy weights + Returns: + * weighted categorical crossentropy function + """ + + def loss(y_true, y_pred): + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum(tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + return tf.reduce_mean(per_pixel_loss) + + return loss + + +def image_categorical_cross_entropy(y_true, y_pred, weights=None): + """ + :param y_true: tensor of shape (batch_size, height, width) representing the ground truth. + :param y_pred: tensor of shape (batch_size, height, width) representing the prediction. + :return: The mean cross-entropy on softmaxed tensors. + """ + + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum( + tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + + return tf.reduce_mean(per_pixel_loss) + + +def class_tversky(y_true, y_pred): + smooth = 1.0 # 1.00 + + y_true = K.permute_dimensions(y_true, (3, 1, 2, 0)) + y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0)) + + y_true_pos = K.batch_flatten(y_true) + y_pred_pos = K.batch_flatten(y_pred) + true_pos = K.sum(y_true_pos * y_pred_pos, 1) + false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1) + false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1) + alpha = 0.2 # 0.5 + beta = 0.8 + return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth) + + +def focal_tversky_loss(y_true, y_pred): + pt_1 = class_tversky(y_true, y_pred) + gamma = 1.3 # 4./3.0#1.3#4.0/3.00# 0.75 + return K.sum(K.pow((1 - pt_1), gamma)) + + +def generalized_dice_coeff2(y_true, y_pred): + n_el = 1 + for dim in y_true.shape: + n_el *= int(dim) + n_cl = y_true.shape[-1] + w = K.zeros(shape=(n_cl,)) + w = (K.sum(y_true, axis=(0, 1, 2))) / n_el + w = 1 / (w ** 2 + 0.000001) + numerator = y_true * y_pred + numerator = w * K.sum(numerator, (0, 1, 2)) + numerator = K.sum(numerator) + denominator = y_true + y_pred + denominator = w * K.sum(denominator, (0, 1, 2)) + denominator = K.sum(denominator) + return 2 * numerator / denominator + + +def generalized_dice_coeff(y_true, y_pred): + axes = tuple(range(1, len(y_pred.shape) - 1)) + Ncl = y_pred.shape[-1] + w = K.zeros(shape=(Ncl,)) + w = K.sum(y_true, axis=axes) + w = 1 / (w ** 2 + 0.000001) + # Compute gen dice coef: + numerator = y_true * y_pred + numerator = w * K.sum(numerator, axes) + numerator = K.sum(numerator) + + denominator = y_true + y_pred + denominator = w * K.sum(denominator, axes) + denominator = K.sum(denominator) + + gen_dice_coef = 2 * numerator / 
denominator + + return gen_dice_coef + + +def generalized_dice_loss(y_true, y_pred): + return 1 - generalized_dice_coeff2(y_true, y_pred) + + +def soft_dice_loss(y_true, y_pred, epsilon=1e-6): + """ + Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. + Assumes the `channels_last` format. + + # Arguments + y_true: b x X x Y( x Z...) x c One hot encoding of ground truth + y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) + epsilon: Used for numerical stability to avoid divide by zero errors + + # References + V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + https://arxiv.org/abs/1606.04797 + More details on Dice loss formulation + https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72) + + Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022 + """ + + # skip the batch and class axis for calculating Dice score + axes = tuple(range(1, len(y_pred.shape) - 1)) + + numerator = 2. * K.sum(y_pred * y_true, axes) + + denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + + +def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, + verbose=False): + """ + Compute mean metrics of two segmentation masks, via Keras. + + IoU(A,B) = |A & B| / (| A U B|) + Dice(A,B) = 2*|A & B| / (|A| + |B|) + + Args: + y_true: true masks, one-hot encoded. + y_pred: predicted masks, either softmax outputs, or one-hot encoded. + metric_name: metric to be computed, either 'iou' or 'dice'. + metric_type: one of 'standard' (default), 'soft', 'naive'. + In the standard version, y_pred is one-hot encoded and the mean + is taken only over classes that are present (in y_true or y_pred). + The 'soft' version of the metrics are computed without one-hot + encoding y_pred. + The 'naive' version return mean metrics where absent classes contribute + to the class mean as 1.0 (instead of being dropped from the mean). + drop_last = True: boolean flag to drop last class (usually reserved + for background class in semantic segmentation) + mean_per_class = False: return mean along batch axis for each class. + verbose = False: print intermediate results such as intersection, union + (as number of pixels). + Returns: + IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True + in which case it returns the per-class metric, averaged over the batch. 
+ + Inputs are B*W*H*N tensors, with + B = batch size, + W = width, + H = height, + N = number of classes + """ + + flag_soft = (metric_type == 'soft') + flag_naive_mean = (metric_type == 'naive') + + # always assume one or more classes + num_classes = K.shape(y_true)[-1] + + if not flag_soft: + # get one-hot encoded masks from y_pred (true masks should already be one-hot) + y_pred = K.one_hot(K.argmax(y_pred), num_classes) + y_true = K.one_hot(K.argmax(y_true), num_classes) + + # if already one-hot, could have skipped above command + # keras uses float32 instead of float64, would give error down (but numpy arrays or keras.to_categorical gives float64) + y_true = K.cast(y_true, 'float32') + y_pred = K.cast(y_pred, 'float32') + + # intersection and union shapes are batch_size * n_classes (values = area in pixels) + axes = (1, 2) # W,H axes of each image + intersection = K.sum(K.abs(y_true * y_pred), axis=axes) + mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes) + union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot + + smooth = .001 + iou = (intersection + smooth) / (union + smooth) + dice = 2 * (intersection + smooth) / (mask_sum + smooth) + + metric = {'iou': iou, 'dice': dice}[metric_name] + + # define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise + mask = K.cast(K.not_equal(union, 0), 'float32') + + if drop_last: + metric = metric[:, :-1] + mask = mask[:, :-1] + + if verbose: + print('intersection, union') + print(K.eval(intersection), K.eval(union)) + print(K.eval(intersection / union)) + + # return mean metrics: remaining axes are (batch, classes) + if flag_naive_mean: + return K.mean(metric) + + # take mean only over non-absent classes + class_count = K.sum(mask, axis=0) + non_zero = tf.greater(class_count, 0) + non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero) + non_zero_count = tf.boolean_mask(class_count, non_zero) + + if verbose: + print('Counts of inputs with class present, metrics for non-absent classes') + print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count)) + + return K.mean(non_zero_sum / non_zero_count) + + +def mean_iou(y_true, y_pred, **kwargs): + """ + Compute mean Intersection over Union of two segmentation masks, via Keras. + + Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs. 
+ """ + return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs) + + +def Mean_IOU(y_true, y_pred): + nb_classes = K.int_shape(y_pred)[-1] + iou = [] + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + void_labels = K.equal(K.sum(y_true, axis=-1), 0) + for i in range(0, nb_classes): # exclude first label (background) and last label (void) + true_labels = K.equal(true_pixels, i) # & ~void_labels + pred_labels = K.equal(pred_pixels, i) # & ~void_labels + inter = tf.to_int32(true_labels & pred_labels) + union = tf.to_int32(true_labels | pred_labels) + legal_batches = K.sum(tf.to_int32(true_labels), axis=1) > 0 + ious = K.sum(inter, axis=1) / K.sum(union, axis=1) + iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects + iou = tf.stack(iou) + legal_labels = ~tf.debugging.is_nan(iou) + iou = tf.gather(iou, indices=tf.where(legal_labels)) + return K.mean(iou) + + +def iou_vahid(y_true, y_pred): + nb_classes = tf.shape(y_true)[-1] + tf.to_int32(1) + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + iou = [] + + for i in tf.range(nb_classes): + tp = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.equal(pred_pixels, i))) + fp = K.sum(tf.to_int32(K.not_equal(true_pixels, i) & K.equal(pred_pixels, i))) + fn = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.not_equal(pred_pixels, i))) + iouh = tp / (tp + fp + fn) + iou.append(iouh) + return K.mean(iou) + + +def IoU_metric(Yi, y_predi): + # mean Intersection over Union + # Mean IoU = TP/(FN + TP + FP) + y_predi = np.argmax(y_predi, axis=3) + y_testi = np.argmax(Yi, axis=3) + IoUs = [] + Nclass = int(np.max(Yi)) + 1 + for c in range(Nclass): + TP = np.sum((Yi == c) & (y_predi == c)) + FP = np.sum((Yi != c) & (y_predi == c)) + FN = np.sum((Yi == c) & (y_predi != c)) + IoU = TP / float(TP + FP + FN) + IoUs.append(IoU) + return K.cast(np.mean(IoUs), dtype='float32') + + +def IoU_metric_keras(y_true, y_pred): + # mean Intersection over Union + # Mean IoU = TP/(FN + TP + FP) + init = tf.global_variables_initializer() + sess = tf.Session() + sess.run(init) + + return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) + + +def jaccard_distance_loss(y_true, y_pred, smooth=100): + """ + Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + + The jaccard distance loss is usefull for unbalanced datasets. This has been + shifted so it converges on 0 and is smoothed to avoid exploding or disapearing + gradient. 
+ + Ref: https://en.wikipedia.org/wiki/Jaccard_index + + @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96 + @author: wassname + """ + intersection = K.sum(K.abs(y_true * y_pred), axis=-1) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + jac = (intersection + smooth) / (sum_ - intersection + smooth) + return (1 - jac) * smooth diff --git a/train/models.py b/train/models.py new file mode 100644 index 0000000..8841bd3 --- /dev/null +++ b/train/models.py @@ -0,0 +1,756 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras.models import * +from tensorflow.keras.layers import * +from tensorflow.keras import layers +from tensorflow.keras.regularizers import l2 + +##mlp_head_units = [512, 256]#[2048, 1024] +###projection_dim = 64 +##transformer_layers = 2#8 +##num_heads = 1#4 +resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +IMAGE_ORDERING = 'channels_last' +MERGE_AXIS = -1 + +def mlp(x, hidden_units, dropout_rate): + for units in hidden_units: + x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dropout(dropout_rate)(x) + return x + +class Patches(layers.Layer): + def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + #patch_dims = patches.shape[-1] + patch_dims = tf.shape(patches)[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size_x': self.patch_size_x, + 'patch_size_y': self.patch_size_y, + }) + return config + +class Patches_old(layers.Layer): + def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + #print(patches.shape,patch_dims,'patch_dims') + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config + + +class PatchEncoder(layers.Layer): + def __init__(self, num_patches, projection_dim): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + def get_config(self): + + config = super().get_config().copy() + config.update({ 
+ 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + + +def one_side_pad(x): + x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) + if IMAGE_ORDERING == 'channels_first': + x = Lambda(lambda x: x[:, :, :-1, :-1])(x) + elif IMAGE_ORDERING == 'channels_last': + x = Lambda(lambda x: x[:, :-1, :-1, :])(x) + return x + + +def identity_block(input_tensor, kernel_size, filters, stage, block): + """The identity block is the block that has no conv layer at shortcut. + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, + padding='same', name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + x = layers.add([x, input_tensor]) + x = Activation('relu')(x) + return x + + +def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): + """conv_block is the block that has a conv layer at shortcut + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. 
+    Note that from stage 3, the first conv layer at main path is with strides=(2,2)
+    And the shortcut should have strides=(2,2) as well
+    """
+    filters1, filters2, filters3 = filters
+
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
+
+    conv_name_base = 'res' + str(stage) + block + '_branch'
+    bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
+               name=conv_name_base + '2a')(input_tensor)
+    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
+    x = Activation('relu')(x)
+
+    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
+               name=conv_name_base + '2b')(x)
+    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
+    x = Activation('relu')(x)
+
+    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
+    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
+
+    shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
+                      name=conv_name_base + '1')(input_tensor)
+    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
+
+    x = layers.add([x, shortcut])
+    x = Activation('relu')(x)
+    return x
+
+
+def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
+    assert input_height % 32 == 0
+    assert input_width % 32 == 0
+
+    img_input = Input(shape=(input_height, input_width, 3))
+
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
+
+    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
+    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
+               name='conv1')(x)
+    f1 = x
+
+    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
+    x = Activation('relu')(x)
+    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
+
+    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
+    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
+    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
+    f2 = one_side_pad(x)
+
+    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
+    f3 = x
+
+    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
+    f4 = x
+
+    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
+    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
+    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
+    f5 = x
+
+    if pretraining:
+        # load_weights returns None, so do not assign its result
+        Model(img_input, x).load_weights(resnet50_Weights_path)
+
+    v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
+    v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
+    v512_2048 = Activation('relu')(v512_2048)
+
+    v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
+    v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024)
+    v512_1024 = 
Activation('relu')(v512_1024) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048) + o = (concatenate([o, v512_1024], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, img_input], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(img_input, o) + return model + + +def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + assert input_height % 32 == 0 + assert input_width % 32 == 0 + + img_input = Input(shape=(input_height, input_width, 3)) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), + name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], 
stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model(img_input, x).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( + f5) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, img_input], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(img_input, o) + + return model + + +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + inputs = layers.Input(shape=(input_height, input_width, 3)) + + #transformer_units = [ + #projection_dim * 2, + #projection_dim, + #] # Size of the transformer layers + IMAGE_ORDERING = 'channels_last' + bn_axis=3 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 
3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + model = Model(inputs, x).load_weights(resnet50_Weights_path) + + #num_patches = x.shape[1]*x.shape[2] + + #patch_size_y = input_height / x.shape[1] + #patch_size_x = input_width / x.shape[2] + #patch_size = patch_size_x * patch_size_y + patches = Patches(patch_size_x, patch_size_y)(x) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) + # Skip connection 2. 
+ encoded_patches = layers.Add()([x3, x2]) + + encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )]) + + v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o ,f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, inputs],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(inputs=inputs, outputs=o) + + return model + +def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + inputs = layers.Input(shape=(input_height, input_width, 3)) + + ##transformer_units = [ + ##projection_dim * 2, + ##projection_dim, + ##] # Size of the transformer layers + IMAGE_ORDERING = 'channels_last' + bn_axis=3 + + patches = Patches(patch_size_x, patch_size_y)(inputs) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. 
+ x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) + # Skip connection 2. + encoded_patches = layers.Add()([x3, x2]) + + encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )]) + + encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches) + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o ,f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), 
data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, inputs],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(inputs=inputs, outputs=o) + + return model + +def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + include_top=True + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model(img_input, x).load_weights(resnet50_Weights_path) + + x = AveragePooling2D((7, 7), name='avg_pool')(x) + x = Flatten()(x) + + ## + x = Dense(256, activation='relu', name='fc512')(x) + x=Dropout(0.2)(x) + ## + x = Dense(n_classes, activation='softmax', name='fc1000')(x) + model = Model(img_input, x) + + + + + return model + +def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1) + + x1 = 
BatchNormalization(axis=bn_axis, name='bn_conv1')(x1) + x1 = Activation('relu')(x1) + x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1) + + x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b') + x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c') + + x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d') + + x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f') + + x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a') + x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b') + x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c') + + if pretraining: + Model(img_input , x1).load_weights(resnet50_Weights_path) + + x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1) + flattened = Flatten()(x1) + + o = Dense(256, activation='relu', name='fc512')(flattened) + o=Dropout(0.2)(o) + + o = Dense(256, activation='relu', name='fc512a')(o) + o=Dropout(0.2)(o) + + o = Dense(n_classes, activation='sigmoid', name='fc1000')(o) + model = Model(img_input , o) + + return model diff --git a/train/requirements.txt b/train/requirements.txt new file mode 100644 index 0000000..d8f9003 --- /dev/null +++ b/train/requirements.txt @@ -0,0 +1,11 @@ +tensorflow == 2.12.1 +sacred +opencv-python-headless +seaborn +tqdm +imutils +numpy +scipy +scikit-learn +shapely +click diff --git a/train/scales_enhancement.json b/train/scales_enhancement.json new file mode 100644 index 0000000..58034f0 --- /dev/null +++ b/train/scales_enhancement.json @@ -0,0 +1,3 @@ +{ + "scales" : [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] +} diff --git a/train/train.md b/train/train.md new file mode 100644 index 0000000..553522b --- /dev/null +++ b/train/train.md @@ -0,0 +1,576 @@ +# Documentation for Training Models + +This repository assists users in preparing training datasets, training models, and performing inference with trained models. We cover various use cases including pixel-wise segmentation, image classification, image enhancement, and machine-based reading order. For each use case, we provide guidance on how to generate the corresponding training dataset. +All these use cases are now utilized in the Eynollah workflow. +As mentioned, the following three tasks can be accomplished using this repository: + +* Generate training dataset +* Train a model +* Inference with the trained model + +## Generate training dataset +The script generate_gt_for_training.py is used for generating training datasets. As the results of the following command demonstrate, the dataset generator provides three different commands: + +`python generate_gt_for_training.py --help` + + +These three commands are: + +* image-enhancement +* machine-based-reading-order +* pagexml2label + + +### image-enhancement + +Generating a training dataset for image enhancement is quite straightforward. All that is needed is a set of high-resolution images. 
The training dataset can then be generated using the following command:
+
+`python generate_gt_for_training.py image-enhancement -dis "dir of high resolution images" -dois "dir where degraded images will be written" -dols "dir where the corresponding high resolution image will be written as label" -scs "degrading scales json file"`
+
+The scales JSON file is a dictionary with a key named "scales" and values representing scales smaller than 1. Images are downscaled based on these scales and then upscaled again to their original size. This process causes the images to lose resolution at different scales. The degraded images are used as input images, and the original high-resolution images serve as labels. The enhancement model can be trained with this generated dataset. The scales JSON file looks like this:
+
+```json
+{
+    "scales": [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
+}
+```
+
+### machine-based-reading-order
+
+For machine-based reading order, we aim to determine the reading priority between a pair of text regions. The model's input is a three-channel image: the first and last channels contain information about each of the two text regions, while the middle channel encodes prominent layout elements necessary for reading order, such as separators and headers. To generate the training dataset, our script requires a page XML file that specifies the image layout with the correct reading order.
+
+For output images, it is necessary to specify the width and height. Additionally, a minimum text region size can be set to filter out regions smaller than this minimum size. This minimum size is defined as the ratio of the text region area to the image area, with a default value of zero. To run the dataset generator, use the following command:
+
+`python generate_gt_for_training.py machine-based-reading-order -dx "dir of GT xml files" -domi "dir where output images will be written" -docl "dir where the labels will be written" -ih "height" -iw "width" -min "min area ratio"`
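+
+A minimal sketch of how such a pair image is composed, mirroring the input encoding used by the inference code in this repository; the masks here are hypothetical placeholders that would come from rasterizing the page XML polygons:
+
+```python
+import numpy as np
+
+# Hypothetical binary masks with shape (height, width): the two text
+# regions being compared, plus a map of prominent layout elements.
+region_a = np.zeros((672, 448), dtype='uint8')  # first text region
+region_b = np.zeros((672, 448), dtype='uint8')  # second text region
+layout = np.zeros((672, 448), dtype='uint8')    # separators, headers, ...
+
+# Three-channel input: the two regions on the outer channels and the
+# layout elements on the middle channel.
+pair_image = np.stack([region_a, layout, region_b], axis=-1)
+
+# The binary label then encodes which of the two regions is read first.
+```
+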
+### pagexml2label
+
+pagexml2label is designed to generate labels from GT page XML files for various pixel-wise segmentation use cases, including "layout", "textline", "printspace", "glyph", and "word" segmentation.
+To train a pixel-wise segmentation model, we require images along with their corresponding labels. Our training script expects a PNG image where each pixel corresponds to a label, represented by an integer. The background is always labeled as zero, while other elements are assigned different integers. For instance, if we have ground truth data with four elements including the background, the classes would be labeled as 0, 1, 2, and 3 respectively.
+
+In binary segmentation scenarios such as textline or page extraction, the background is encoded as 0, and the desired element is automatically encoded as 1 in the PNG label.
+
+To specify the desired use case and the elements to be extracted in the PNG labels, a custom JSON file can be passed. For example, in the case of "textline" detection, the JSON file would resemble this:
+
+```json
+{
+    "use_case": "textline"
+}
+```
+
+In the case of layout segmentation, a possible custom config JSON file can look like this:
+
+```json
+{
+    "use_case": "layout",
+    "textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
+    "imageregion": 4,
+    "separatorregion": 5,
+    "graphicregions": {"rest_as_decoration": 6, "stamp": 7}
+}
+```
+
+A possible custom config JSON file for layout segmentation where "printspace" is also wanted as a class:
+
+```json
+{
+    "use_case": "layout",
+    "textregions": {"rest_as_paragraph": 1, "drop-capital": 1, "header": 2, "heading": 2, "marginalia": 3},
+    "imageregion": 4,
+    "separatorregion": 5,
+    "graphicregions": {"rest_as_decoration": 6, "stamp": 7},
+    "printspace_as_class_in_layout": 8
+}
+```
+
+For the layout use case, it is beneficial to first understand the structure of the page XML file and its elements. In a given image, the annotations of elements are recorded in a page XML file, including their contours and classes. For an image document, the known regions are 'textregion', 'separatorregion', 'imageregion', 'graphicregion', 'noiseregion', and 'tableregion'.
+
+Text regions and graphic regions also have their own specific types. The known types for text regions are 'paragraph', 'header', 'heading', 'marginalia', 'drop-capital', 'footnote', 'footnote-continued', 'signature-mark', 'page-number', and 'catch-word'. The known types for graphic regions are 'handwritten-annotation', 'decoration', 'stamp', and 'signature'.
+Since we don't know all types of text and graphic regions, unknown cases can arise. To handle these, we have defined two additional types, "rest_as_paragraph" and "rest_as_decoration", to ensure that no unknown types are missed. This way, users can extract all known types from the labels and be confident that no unknown types are overlooked.
+
+In the custom JSON file shown above, "header" and "heading" are extracted as the same class, while "marginalia" is shown as a different class. All other text region types, including "drop-capital", are grouped into the same class. For the graphic region, "stamp" has its own class, while all other types are classified together. "Image region" and "separator region" are also present in the label. However, other regions like "noise region" and "table region" will not be included in the label PNG file, even if they have information in the page XML files, as we chose not to include them.
+
+`python generate_gt_for_training.py pagexml2label -dx "dir of GT xml files" -do "dir where output label png files will be written" -cfg "custom config json file" -to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels"`
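+
+A minimal sketch of what the 2D label encoding amounts to; the polygons and class ids below are hypothetical examples, whereas the actual script derives them from the page XML and the custom config:
+
+```python
+import numpy as np
+import cv2
+
+height, width = 1000, 700  # page size taken from the page XML
+label = np.zeros((height, width), dtype='uint8')  # background = 0
+
+# Hypothetical example: one paragraph polygon mapped to class 1 and one
+# separator polygon mapped to class 5, as in the layout config above.
+paragraph = np.array([[50, 50], [650, 50], [650, 400], [50, 400]], dtype=np.int32)
+separator = np.array([[50, 420], [650, 420], [650, 430], [50, 430]], dtype=np.int32)
+
+cv2.fillPoly(label, pts=[paragraph], color=1)
+cv2.fillPoly(label, pts=[separator], color=5)
+
+cv2.imwrite('page_0001.png', label)  # each pixel value is a class id
+```
+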
+We have also defined an artificial class that can be added to the boundary of text region types or text lines. This key is called "artificial_class_on_boundary". If users want to apply this to certain text regions in the layout use case, the JSON config file should look like this:
+
+```yaml
+{
+    "use_case": "layout",
+    "textregions": {
+        "paragraph": 1,
+        "drop-capital": 1,
+        "header": 2,
+        "heading": 2,
+        "marginalia": 3
+    },
+    "imageregion": 4,
+    "separatorregion": 5,
+    "graphicregions": {
+        "rest_as_decoration": 6
+    },
+    "artificial_class_on_boundary": ["paragraph", "header", "heading", "marginalia"],
+    "artificial_class_label": 7
+}
+```
+
+This means that the artificial class, labeled 7, will appear in the PNG files, and only on the boundaries of elements labeled "paragraph", "header", "heading", and "marginalia".
+
+For "textline", "word", and "glyph", the artificial class on the boundaries is activated only if the "artificial_class_label" key is specified in the config file. Its value should be 2, since these are binary cases: if the background and textline are denoted as 0 and 1 respectively, the artificial class gets the value 2. For the "textline" use case, the JSON config file should look like this:
+
+```yaml
+{
+    "use_case": "textline",
+    "artificial_class_label": 2
+}
+```
+
+If the coordinates of "PrintSpace" or "Border" are present in the page XML ground truth files and the user wishes to crop only the print space area, this can be achieved with the "-ps" argument. Note that in this scenario, since cropping is applied to the label files, the directory of the original images must also be provided so that they are cropped in sync with the labels. This ensures that matching images and labels are obtained for training. The command should resemble the following:
+
+`python generate_gt_for_training.py pagexml2label -dx "dir of GT xml files" -do "dir where output label png files will be written" -cfg "custom config json file" -to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" -ps -di "dir where the org images are located" -doi "dir where the cropped output images will be written"`
+
+## Train a model
+### classification
+
+For the classification use case, we haven't provided a ground truth generator, as it's unnecessary. All we require is a training directory with subdirectories, each containing the images of its respective class. We need separate directories for training and evaluation, and the class names (subdirectories) must be consistent across both. Additionally, the class names must be specified in the config JSON file, as shown in the following example. If, for instance, we aim to classify "apple" and "orange", with a total of 2 classes, the "classification_classes_name" key in the config file should appear as follows:
+
+```yaml
+{
+    "backbone_type" : "nontransformer",
+    "task": "classification",
+    "n_classes" : 2,
+    "n_epochs" : 10,
+    "input_height" : 448,
+    "input_width" : 448,
+    "weight_decay" : 1e-6,
+    "n_batch" : 4,
+    "learning_rate": 1e-4,
+    "f1_threshold_classification": 0.8,
+    "pretraining" : true,
+    "classification_classes_name" : {"0":"apple", "1":"orange"},
+    "dir_train": "./train",
+    "dir_eval": "./eval",
+    "dir_output": "./output"
+}
+```
+
+The "dir_train" should be structured like this:
+
+```
+.
+└── train       # train directory
+    ├── apple   # directory of images for apple class
+    └── orange  # directory of images for orange class
+```
+
+And "dir_eval" should have the same structure as the train directory:
+
+```
+.
+└── eval        # evaluation directory
+    ├── apple   # directory of images for apple class
+    └── orange  # directory of images for orange class
+```
+
+The classification model can be trained using the following command line:
+
+`python train.py with config_classification.json`
+
+As evident in the example JSON file above, classification uses an "f1_threshold_classification" parameter. It is used to collect all models whose evaluation f1 score surpasses this threshold. The weights of these models are then averaged into an ensemble model, which is saved in the output directory as "model_ens_avg". Additionally, the weights of the best model based on the evaluation f1 score are saved as "model_best".
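+The ensembling step is a simple element-wise average over the selected models' weights. Conceptually (a sketch mirroring the averaging done in train.py):
+
+```python
+import numpy as np
+import tensorflow as tf
+
+def average_weights(model, collected_weights):
+    """Average a list of Keras weight sets (one entry per selected model)."""
+    averaged = [np.mean(np.array(ws), axis=0) for ws in zip(*collected_weights)]
+    model_avg = tf.keras.models.clone_model(model)
+    model_avg.set_weights(averaged)
+    return model_avg
+```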
+### reading order
+
+An example config JSON file for machine-based reading order looks like this:
+
+```yaml
+{
+    "backbone_type" : "nontransformer",
+    "task": "reading_order",
+    "n_classes" : 1,
+    "n_epochs" : 5,
+    "input_height" : 672,
+    "input_width" : 448,
+    "weight_decay" : 1e-6,
+    "n_batch" : 4,
+    "learning_rate": 1e-4,
+    "pretraining" : true,
+    "dir_train": "./train",
+    "dir_eval": "./eval",
+    "dir_output": "./output"
+}
+```
+
+The "dir_train" should be structured like this:
+
+```
+.
+└── train       # train directory
+    ├── images  # directory of images
+    └── labels  # directory of labels
+```
+
+And "dir_eval" should have the same structure as the train directory:
+
+```
+.
+└── eval        # evaluation directory
+    ├── images  # directory of images
+    └── labels  # directory of labels
+```
+
+The reading order model is trained with the same command line as in the classification case.
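+As described in the ground truth section, the model input is a three-channel image: the two text regions under comparison in the first and last channels, and prominent layout elements (separators, headers) in the middle channel. A rough sketch of assembling such an input, assuming the binary masks are already available:
+
+```python
+import numpy as np
+
+def make_reading_order_input(region_a, region_b, layout_elements):
+    """Stack two text-region masks and a layout-element mask (all HxW arrays)
+    into a three-channel image: regions in channels 0 and 2, layout in channel 1."""
+    return np.stack([region_a, layout_elements, region_b], axis=-1).astype(np.uint8)
+```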
+### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
+
+#### Parameter configuration for segmentation or enhancement use cases
+
+The following parameters apply to all segmentation use cases and to enhancement. Augmentation, its sub-parameters, and continued training are defined only for segmentation and enhancement, not for classification or machine-based reading order, as their example config files show. The transformer-related size constraints are checked at startup; a sketch of that check follows the list below.
+
+* backbone_type: For segmentation tasks (such as textline, binarization, and layout detection) and enhancement, we offer two backbones: "nontransformer" and "transformer". With the "transformer" backbone, a CNN is applied first, followed by a transformer. The "nontransformer" backbone uses only a ResNet-50 CNN.
+* task: The task parameter can take the values "segmentation", "enhancement", "classification", and "reading_order".
+* patches: Set to ``true`` to break input images into smaller patches (the input size of the model). If the model should see the image as a whole, as in page extraction, set it to ``false``.
+* n_batch: Batch size, i.e. the number of samples per iteration.
+* n_classes: Number of classes. For binary cases this should be 2; for reading_order it must be set to 1; for layout detection, give the number of unique classes.
+* n_epochs: Number of epochs.
+* input_height: Height of the model's input.
+* input_width: Width of the model's input.
+* weight_decay: Weight decay for the L2 regularization of the model layers.
+* pretraining: Set to ``true`` to load pretrained weights for the ResNet-50 encoder. The downloaded weights should be saved in a folder named "pretrained_model" in the same directory as the "train.py" script.
+* augmentation: To apply any kind of augmentation, this parameter must first be set to ``true``.
+* flip_aug: If ``true``, different types of flipping will be applied to the image. The flip types are given with the "flip_index" parameter.
+* blur_aug: If ``true``, different types of blurring will be applied to the image. The blur types are given with the "blur_k" parameter.
+* scaling: If ``true``, scaling will be applied to the image. The scaling factors are given with the "scales" parameter.
+* degrading: If ``true``, degrading will be applied to the image. The amount of degrading is defined with the "degrade_scales" parameter.
+* brightening: If ``true``, brightening will be applied to the image. The amount of brightening is defined with the "brightness" parameter.
+* rotation_not_90: If ``true``, rotations by angles other than 90 degrees will be applied to the image. The rotation angles are given with the "thetha" parameter.
+* rotation: If ``true``, a 90 degree rotation will be applied to the image.
+* binarization: If ``true``, Otsu thresholding will be applied to augment the input data with binarized images.
+* scaling_bluring: If ``true``, a combination of scaling and blurring will be applied to the image.
+* scaling_binarization: If ``true``, a combination of scaling and binarization will be applied to the image.
+* scaling_flip: If ``true``, a combination of scaling and flipping will be applied to the image.
+* flip_index: Types of flips.
+* blur_k: Types of blurring.
+* scales: Scaling factors.
+* brightness: Brightening factors.
+* thetha: Rotation angles.
+* degrade_scales: Degrading scales.
+* continue_training: If ``true``, you have already trained a model and would like to continue training it. You then need to provide the directory of the trained model with "dir_of_start_model" and an index for naming the new models. For example, if you have already trained for 3 epochs, your last index is 2; if you want to continue from model_1.h5, set ``index_start`` to 3 so that new models are named starting from index 3.
+* weighted_loss: If ``true``, weighted categorical cross-entropy is used as the loss function. Be careful: if this is set to ``true``, the parameter "is_loss_soft_dice" must be ``false``.
+* data_is_provided: Set to ``true`` if you have already provided prepared input data; the train and eval data must then be in "dir_output". When raw data is provided instead, it is resized, augmented, and written to the train and eval sub-directories of "dir_output".
+* dir_train: Directory of raw images and labels; it must contain the two subdirectories "images" and "labels". This data is not yet prepared (resized or augmented) for training. When the tool runs, the raw data is transformed to the size needed by the model and written to the train and eval directories in "dir_output", each with "images" and "labels" sub-directories.
+* index_start: Starting index for saved models when "continue_training" is ``true``.
+* dir_of_start_model: Directory containing the pretrained model to continue training from when "continue_training" is ``true``.
+* transformer_num_patches_xy: Number of vision transformer patches in x and y direction respectively.
+* transformer_patchsize_x: Patch size of vision transformer patches in x direction.
+* transformer_patchsize_y: Patch size of vision transformer patches in y direction.
+* transformer_projection_dim: Transformer projection dimension. Default value is 64.
+* transformer_mlp_head_units: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
+* transformer_layers: Number of transformer layers. Default value is 8.
+* transformer_num_heads: Number of transformer heads. Default value is 4.
+* transformer_cnn_first: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer; in the other, this order is reversed. If transformer_cnn_first is ``true``, the CNN is applied before the transformer. Default value is ``true``.
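+When the transformer backbone is used with transformer_cnn_first set to true, train.py verifies that input_height equals transformer_num_patches_xy[1] * transformer_patchsize_y * 32 (and analogously for the width, since the CNN downsamples by a factor of 32), and that transformer_projection_dim is divisible by transformer_patchsize_x * transformer_patchsize_y. A sketch of the same arithmetic as a quick config check (the config path is a placeholder):
+
+```python
+import json
+
+with open("config_params.json") as f:  # hypothetical config file
+    cfg = json.load(f)
+
+px, py = cfg["transformer_patchsize_x"], cfg["transformer_patchsize_y"]
+nx, ny = cfg["transformer_num_patches_xy"]
+
+# with a CNN-first transformer backbone, the CNN downsamples by 32, so the
+# input must tile exactly into the requested number of transformer patches
+assert cfg["input_height"] == ny * py * 32, "input_height mismatch"
+assert cfg["input_width"] == nx * px * 32, "input_width mismatch"
+assert cfg["transformer_projection_dim"] % (px * py) == 0, "projection dim mismatch"
+```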
+In the case of segmentation and enhancement, the train and evaluation directories should be structured as follows.
+
+The "dir_train" should be like this:
+
+```
+.
+└── train       # train directory
+    ├── images  # directory of images
+    └── labels  # directory of labels
+```
+
+And "dir_eval" should have the same structure as the train directory:
+
+```
+.
+└── eval        # evaluation directory
+    ├── images  # directory of images
+    └── labels  # directory of labels
+```
+
+After configuring the JSON file for segmentation or enhancement, training is started with the same command as for classification and reading order, passing the corresponding config file:
+
+`python train.py with config_params.json`
+
+#### Binarization
+
+An example config JSON file for binarization looks like this:
+
+```yaml
+{
+    "backbone_type" : "transformer",
+    "task": "binarization",
+    "n_classes" : 2,
+    "n_epochs" : 4,
+    "input_height" : 224,
+    "input_width" : 672,
+    "weight_decay" : 1e-6,
+    "n_batch" : 1,
+    "learning_rate": 1e-4,
+    "patches" : true,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": false,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "transformer_num_patches_xy": [7, 7],
+    "transformer_patchsize_x": 3,
+    "transformer_patchsize_y": 1,
+    "transformer_projection_dim": 192,
+    "transformer_mlp_head_units": [128, 64],
+    "transformer_layers": 8,
+    "transformer_num_heads": 4,
+    "transformer_cnn_first": true,
+    "blur_k" : ["blur","gauss","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
+    "continue_training": false,
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": false,
+    "data_is_provided": false,
+    "dir_train": "./train",
+    "dir_eval": "./eval",
+    "dir_output": "./output"
+}
+```
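+With "patches": true, as in the config above, training images are tiled into model-sized crops (cf. get_patches in utils.py). The tiling logic is roughly the following sketch; edge tiles are shifted inward so that every tile lies fully inside the image:
+
+```python
+import math
+
+def patch_grid(img_h, img_w, patch_h, patch_w):
+    """Yield the top-left corners of patch_h x patch_w tiles covering the image."""
+    for i in range(math.ceil(img_w / patch_w)):
+        for j in range(math.ceil(img_h / patch_h)):
+            x = min(i * patch_w, img_w - patch_w)  # clamp the last column inward
+            y = min(j * patch_h, img_h - patch_h)  # clamp the last row inward
+            yield y, x
+```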
"dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +#### Enhancement + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "enhancement", + "n_classes" : 3, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +It's important to mention that the value of n_classes for enhancement should be 3, as the model's output is a 3-channel image. + +#### Page extraction + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 2, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : false, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +For page segmentation (or print space or border segmentation), the model needs to view the input image in its entirety, hence the patches parameter should be set to false. 
+
+#### Layout segmentation
+
+An example config JSON file for layout segmentation with 5 classes (including the background) looks like this:
+
+```yaml
+{
+    "backbone_type" : "transformer",
+    "task": "segmentation",
+    "n_classes" : 5,
+    "n_epochs" : 4,
+    "input_height" : 448,
+    "input_width" : 224,
+    "weight_decay" : 1e-6,
+    "n_batch" : 1,
+    "learning_rate": 1e-4,
+    "patches" : true,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": false,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "transformer_num_patches_xy": [7, 14],
+    "transformer_patchsize_x": 1,
+    "transformer_patchsize_y": 1,
+    "transformer_projection_dim": 64,
+    "transformer_mlp_head_units": [128, 64],
+    "transformer_layers": 8,
+    "transformer_num_heads": 4,
+    "transformer_cnn_first": true,
+    "blur_k" : ["blur","gauss","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
+    "continue_training": false,
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": false,
+    "data_is_provided": false,
+    "dir_train": "./train",
+    "dir_eval": "./eval",
+    "dir_output": "./output"
+}
+```
+## Inference with the trained model
+### classification
+
+To run inference with a trained classification model, execute the following command, specifying the model directory and the image to classify:
+
+`python inference.py -m "model dir" -i "image"`
+
+This returns the class of the image.
+
+### machine-based reading order
+
+To infer the reading order with a reading order model, we need a page XML file containing the layout information but without the reading order. We simply provide the model directory, the XML file, and the output directory. The new XML file with the added reading order is written to the output directory under the same name. We need to run:
+
+`python inference.py -m "model dir" -xml "page xml file" -o "output dir to write new xml with reading order"`
+
+### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement
+
+To run inference with a trained model for segmentation or enhancement, use the following command:
+
+`python inference.py -m "model dir" -i "image" -p -s "output image"`
+
+Note that in the case of page extraction the -p flag is not needed.
+
+For segmentation or binarization tasks, if a ground truth (GT) label is available, the IoU evaluation metric can be calculated for the output. To do this, provide the GT label with the -gt argument.
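+The reported metric is the mean IoU over the classes present in the ground truth (see IoU in utils.py). The computation is essentially:
+
+```python
+import numpy as np
+
+def mean_iou(y_true, y_pred):
+    """Mean intersection-over-union over the classes present in y_true
+    (cf. IoU in utils.py): IoU_c = TP / (TP + FP + FN)."""
+    ious = []
+    for c in np.unique(y_true):
+        tp = np.sum((y_true == c) & (y_pred == c))
+        fp = np.sum((y_true != c) & (y_pred == c))
+        fn = np.sum((y_true == c) & (y_pred != c))
+        ious.append(tp / float(tp + fp + fn))
+    return float(np.mean(ious))
+```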
+ + + diff --git a/train/train.py b/train/train.py new file mode 100644 index 0000000..e8e92af --- /dev/null +++ b/train/train.py @@ -0,0 +1,451 @@ +import os +import sys +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +import tensorflow as tf +from tensorflow.compat.v1.keras.backend import set_session +import warnings +from tensorflow.keras.optimizers import * +from sacred import Experiment +from models import * +from utils import * +from metrics import * +from tensorflow.keras.models import load_model +from tqdm import tqdm +import json +from sklearn.metrics import f1_score +from tensorflow.keras.callbacks import Callback + +class SaveWeightsAfterSteps(Callback): + def __init__(self, save_interval, save_path, _config): + super(SaveWeightsAfterSteps, self).__init__() + self.save_interval = save_interval + self.save_path = save_path + self.step_count = 0 + self._config = _config + + def on_train_batch_end(self, batch, logs=None): + self.step_count += 1 + + if self.step_count % self.save_interval ==0: + save_file = f"{self.save_path}/model_step_{self.step_count}" + #os.system('mkdir '+save_file) + + self.model.save(save_file) + + with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp: + json.dump(self._config, fp) # encode dict into JSON + print(f"saved model as steps {self.step_count} to {save_file}") + + +def configuration(): + config = tf.compat.v1.ConfigProto() + config.gpu_options.allow_growth = True + session = tf.compat.v1.Session(config=config) + set_session(session) + + +def get_dirs_or_files(input_data): + if os.path.isdir(input_data): + image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') + # Check if training dir exists + assert os.path.isdir(image_input), "{} is not a directory".format(image_input) + assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) + return image_input, labels_input + + +ex = Experiment(save_git_info=False) + + +@ex.config +def config_params(): + n_classes = None # Number of classes. In the case of binary classification this should be 2. + n_epochs = 1 # Number of epochs. + input_height = 224 * 1 # Height of model's input in pixels. + input_width = 224 * 1 # Width of model's input in pixels. + weight_decay = 1e-6 # Weight decay of l2 regularization of model layers. + n_batch = 1 # Number of batches at each iteration. + learning_rate = 1e-4 # Set the learning rate. + patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. + augmentation = False # To apply any kind of augmentation, this parameter must be set to true. + flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json. + blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json. + padding_white = False # If true, white padding will be applied to the image. + padding_black = False # If true, black padding will be applied to the image. + scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json. + shifting = False + degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. 
+ brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. + binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. + adding_rgb_background = False + adding_rgb_foreground = False + add_red_textlines = False + channels_shuffling = False + dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". + dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". + dir_output = None # Directory where the output model will be saved. + pretraining = False # Set to true to load pretrained weights of ResNet50 encoder. + scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. + scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. + rotation = False # If true, a 90 degree rotation will be implemeneted. + rotation_not_90 = False # If true rotation based on provided angles with thetha will be implemeneted. + scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. + scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. + thetha = None # Rotate image by these angles for augmentation. + shuffle_indexes = None + blur_k = None # Blur image for augmentation. + scales = None # Scale patches for augmentation. + degrade_scales = None # Degrade image for augmentation. + brightness = None # Brighten image for augmentation. + flip_index = None # Flip image for augmentation. + continue_training = False # Set to true if you would like to continue training an already trained a model. + transformer_patchsize_x = None # Patch size of vision transformer patches in x direction. + transformer_patchsize_y = None # Patch size of vision transformer patches in y direction. + transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively. + transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64. + transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64] + transformer_layers = 8 # transformer layers. Default value is 8. + transformer_num_heads = 4 # Transformer number of heads. Default value is 4. + transformer_cnn_first = True # We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true. + index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. + dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. + is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. + weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. + data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". 
+ task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification. + f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. + classification_classes_name = None # Dictionary of classification classes names. + backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer" + save_interval = None + dir_img_bin = None + number_of_backgrounds_per_image = 1 + dir_rgb_backgrounds = None + dir_rgb_foregrounds = None + + +@ex.automain +def run(_config, n_classes, n_epochs, input_height, + input_width, weight_decay, weighted_loss, + index_start, dir_of_start_model, is_loss_soft_dice, + n_batch, patches, augmentation, flip_aug, + blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling, + brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, + brightness, dir_train, data_is_provided, scaling_bluring, + scaling_brightness, scaling_binarization, rotation, rotation_not_90, + thetha, scaling_flip, continue_training, transformer_projection_dim, + transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, + transformer_patchsize_x, transformer_patchsize_y, + transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output, + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): + + if dir_rgb_backgrounds: + list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) + else: + list_all_possible_background_images = None + + if dir_rgb_foregrounds: + list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds) + else: + list_all_possible_foreground_rgbs = None + + if task == "segmentation" or task == "enhancement" or task == "binarization": + if data_is_provided: + dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') + + + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') + + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels') + + configuration() + + else: + dir_img, dir_seg = get_dirs_or_files(dir_train) + dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) + + # make first a directory in output for both training and evaluations in order to flow data from these directories. 
+ dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') + + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/') + + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/') + + if os.path.isdir(dir_train_flowing): + os.system('rm -rf ' + dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) + + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf ' + dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) + + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) + + os.mkdir(dir_flow_eval_imgs) + os.mkdir(dir_flow_eval_labels) + + # set the gpu configuration + configuration() + + imgs_list=np.array(os.listdir(dir_img)) + segs_list=np.array(os.listdir(dir_seg)) + + imgs_list_test=np.array(os.listdir(dir_img_val)) + segs_list_test=np.array(os.listdir(dir_seg_val)) + + # writing patches into a sub-folder in order to be flowed from directory. + provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, + dir_flow_train_labels, input_height, input_width, blur_k, + blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,adding_rgb_foreground, add_red_textlines, channels_shuffling, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, + flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, + patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs) + + provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, + dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, + flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds,dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs ) + + if weighted_loss: + weights = np.zeros(n_classes) + if data_is_provided: + for obj in os.listdir(dir_flow_train_labels): + try: + label_obj = cv2.imread(dir_flow_train_labels + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + else: + + for obj in os.listdir(dir_seg): + try: + label_obj = cv2.imread(dir_seg + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + + weights = 1.00 / 
weights + + weights = weights / float(np.sum(weights)) + weights = weights / float(np.min(weights)) + weights = weights / float(np.sum(weights)) + + if continue_training: + if backbone_type=='nontransformer': + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): + model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) + if weighted_loss and (task == "segmentation" or task == "binarization"): + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True) + elif backbone_type=='transformer': + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): + model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) + if weighted_loss and (task == "segmentation" or task == "binarization"): + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + else: + index_start = 0 + if backbone_type=='nontransformer': + model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) + elif backbone_type=='transformer': + num_patches_x = transformer_num_patches_xy[0] + num_patches_y = transformer_num_patches_xy[1] + num_patches = num_patches_x * num_patches_y + + if transformer_cnn_first: + if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + + + model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + else: + if (input_height != (num_patches_y * transformer_patchsize_y) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. 
The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + + #if you want to see the model structure just uncomment model summary. + model.summary() + + + if (task == "segmentation" or task == "binarization"): + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + elif task == "enhancement": + model.compile(loss='mean_squared_error', + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + + + # generating train and evaluation data + train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) + val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) + + ##img_validation_patches = os.listdir(dir_flow_eval_imgs) + ##score_best=[] + ##score_best.append(0) + + if save_interval: + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) + + + for i in tqdm(range(index_start, n_epochs + index_start)): + if save_interval: + model.fit( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + validation_data=val_gen, + validation_steps=1, + epochs=1, callbacks=[save_weights_callback]) + else: + model.fit( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + validation_data=val_gen, + validation_steps=1, + epochs=1) + + model.save(os.path.join(dir_output,'model_'+str(i))) + + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + + #os.system('rm -rf '+dir_train_flowing) + #os.system('rm -rf '+dir_eval_flowing) + + #model.save(dir_output+'/'+'model'+'.h5') + elif task=='classification': + configuration() + model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining) + + opt_adam = Adam(learning_rate=0.001) + model.compile(loss='categorical_crossentropy', + optimizer = opt_adam,metrics=['accuracy']) + + + list_classes = list(classification_classes_name.values()) + testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes) + + y_tot=np.zeros((testX.shape[0],n_classes)) + + score_best=[] + score_best.append(0) + + num_rows = return_number_of_total_training_data(dir_train) + weights=[] + + for i in range(n_epochs): + history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights) + + y_pr_class = [] + for jj in range(testY.shape[0]): + y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), 
verbose=0) + y_pr_ind= np.argmax(y_pr,axis=1) + y_pr_class.append(y_pr_ind) + + y_pr_class = np.array(y_pr_class) + f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro') + print(i,f1score) + + if f1score>score_best[0]: + score_best[0]=f1score + model.save(os.path.join(dir_output,'model_best')) + + if f1score > f1_threshold_classification: + weights.append(model.get_weights() ) + + + if len(weights) >= 1: + new_weights=list() + for weights_list_tuple in zip(*weights): + new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) + + new_weights = [np.array(x) for x in new_weights] + model_weight_averaged=tf.keras.models.clone_model(model) + model_weight_averaged.set_weights(new_weights) + + model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) + with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + + with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + + elif task=='reading_order': + configuration() + model = machine_based_reading_order_model(n_classes,input_height,input_width,weight_decay,pretraining) + + dir_flow_train_imgs = os.path.join(dir_train, 'images') + dir_flow_train_labels = os.path.join(dir_train, 'labels') + + classes = os.listdir(dir_flow_train_labels) + if augmentation: + num_rows = len(classes)*(len(thetha) + 1) + else: + num_rows = len(classes) + #ls_test = os.listdir(dir_flow_train_labels) + + #f1score_tot = [0] + indexer_start = 0 + opt = SGD(learning_rate=0.01, momentum=0.9) + opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) + model.compile(loss="binary_crossentropy", + optimizer = opt_adam,metrics=['accuracy']) + + if save_interval: + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) + + for i in range(n_epochs): + if save_interval: + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1, callbacks=[save_weights_callback]) + else: + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1) + model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) + + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + ''' + if f1score>f1score_tot[0]: + f1score_tot[0] = f1score + model_dir = os.path.join(dir_out,'model_best') + model.save(model_dir) + ''' + + diff --git a/train/utils.py b/train/utils.py new file mode 100644 index 0000000..2bb7261 --- /dev/null +++ b/train/utils.py @@ -0,0 +1,1056 @@ +import os +import cv2 +import numpy as np +import seaborn as sns +from scipy.ndimage.interpolation import map_coordinates +from scipy.ndimage.filters import gaussian_filter +import random +from tqdm import tqdm +import imutils +import math +from tensorflow.keras.utils import to_categorical +from PIL import Image, ImageEnhance + + +def return_shuffled_channels(img, channels_order): + """ + channels order in ordinary case is like this [0, 1, 2]. In the case of shuffling the order should be provided. 
+ """ + img_sh = np.copy(img) + + img_sh[:,:,0]= img[:,:,channels_order[0]] + img_sh[:,:,1]= img[:,:,channels_order[1]] + img_sh[:,:,2]= img[:,:,channels_order[2]] + return img_sh + +def return_binary_image_with_red_textlines(img_bin): + img_red = np.copy(img_bin) + + img_red[:,:,0][img_bin[:,:,0] == 0] = 255 + return img_red + +def return_binary_image_with_given_rgb_background(img_bin, img_rgb_background): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_bin) + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + return img_final + +def return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin, img_rgb_background, rgb_foreground): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_bin) + img_foreground = np.zeros(img_bin.shape) + + + img_foreground[:,:,0][img_bin[:,:,0] == 0] = rgb_foreground[0] + img_foreground[:,:,1][img_bin[:,:,0] == 0] = rgb_foreground[1] + img_foreground[:,:,2][img_bin[:,:,0] == 0] = rgb_foreground[2] + + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + img_final = img_final + img_foreground + return img_final + +def return_binary_image_with_given_rgb_background_red_textlines(img_bin, img_rgb_background, img_color): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_color) + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + return img_final + +def return_image_with_red_elements(img, img_bin): + img_final = np.copy(img) + + img_final[:,:,0][img_bin[:,:,0]==0] = 0 + img_final[:,:,1][img_bin[:,:,0]==0] = 0 + img_final[:,:,2][img_bin[:,:,0]==0] = 255 + return img_final + +def shift_image_and_label(img, label, type_shift): + h_n = int(img.shape[0]*1.06) + w_n = int(img.shape[1]*1.06) + + channel0_avg = int( np.mean(img[:,:,0]) ) + channel1_avg = int( np.mean(img[:,:,1]) ) + channel2_avg = int( np.mean(img[:,:,2]) ) + + h_diff = abs( img.shape[0] - h_n ) + w_diff = abs( img.shape[1] - w_n ) + + h_start = int(h_diff / 2.) + w_start = int(w_diff / 2.) 
+ + img_scaled_padded = np.zeros((h_n, w_n, 3)) + label_scaled_padded = np.zeros((h_n, w_n, 3)) + + img_scaled_padded[:,:,0] = channel0_avg + img_scaled_padded[:,:,1] = channel1_avg + img_scaled_padded[:,:,2] = channel2_avg + + img_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = img[:,:,:] + label_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = label[:,:,:] + + + if type_shift=="xpos": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xmin": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + elif type_shift=="ypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="ymin": + img_dis = img_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="xypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xymin": + img_dis = img_scaled_padded[:img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],:img.shape[1],:] + return img_dis, label_dis + +def scale_image_for_no_patch(img, label, scale): + h_n = int(img.shape[0]*scale) + w_n = int(img.shape[1]*scale) + + channel0_avg = int( np.mean(img[:,:,0]) ) + channel1_avg = int( np.mean(img[:,:,1]) ) + channel2_avg = int( np.mean(img[:,:,2]) ) + + h_diff = img.shape[0] - h_n + w_diff = img.shape[1] - w_n + + h_start = int(h_diff / 2.) + w_start = int(w_diff / 2.) 
+ + img_res = resize_image(img, h_n, w_n) + label_res = resize_image(label, h_n, w_n) + + img_scaled_padded = np.copy(img) + + label_scaled_padded = np.zeros(label.shape) + + img_scaled_padded[:,:,0] = channel0_avg + img_scaled_padded[:,:,1] = channel1_avg + img_scaled_padded[:,:,2] = channel2_avg + + img_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = img_res[:,:,:] + label_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = label_res[:,:,:] + + return img_scaled_padded, label_scaled_padded + + +def return_number_of_total_training_data(path_classes): + sub_classes = os.listdir(path_classes) + n_tot = 0 + for sub_c in sub_classes: + sub_files = os.listdir(os.path.join(path_classes,sub_c)) + n_tot = n_tot + len(sub_files) + return n_tot + + + +def generate_data_from_folder_evaluation(path_classes, height, width, n_classes, list_classes): + #sub_classes = os.listdir(path_classes) + #n_classes = len(sub_classes) + all_imgs = [] + labels = [] + #dicts =dict() + #indexer= 0 + for indexer, sub_c in enumerate(list_classes): + sub_files = os.listdir(os.path.join(path_classes,sub_c )) + sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] + #print( os.listdir(os.path.join(path_classes,sub_c )) ) + all_imgs = all_imgs + sub_files + sub_labels = list( np.zeros( len(sub_files) ) +indexer ) + + #print( len(sub_labels) ) + labels = labels + sub_labels + #dicts[sub_c] = indexer + #indexer +=1 + + + categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] + ret_x= np.zeros((len(labels), height,width, 3)).astype(np.int16) + ret_y= np.zeros((len(labels), n_classes)).astype(np.int16) + + #print(all_imgs) + for i in range(len(all_imgs)): + row = all_imgs[i] + #####img = cv2.imread(row, 0) + #####img= resize_image (img, height, width) + #####img = img.astype(np.uint16) + #####ret_x[i, :,:,0] = img[:,:] + #####ret_x[i, :,:,1] = img[:,:] + #####ret_x[i, :,:,2] = img[:,:] + + img = cv2.imread(row) + img= resize_image (img, height, width) + img = img.astype(np.uint16) + ret_x[i, :,:] = img[:,:,:] + + ret_y[i, :] = categories[ int( labels[i] ) ][:] + + return ret_x/255., ret_y + +def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes): + #sub_classes = os.listdir(path_classes) + #n_classes = len(sub_classes) + + all_imgs = [] + labels = [] + #dicts =dict() + #indexer= 0 + for indexer, sub_c in enumerate(list_classes): + sub_files = os.listdir(os.path.join(path_classes,sub_c )) + sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] + #print( os.listdir(os.path.join(path_classes,sub_c )) ) + all_imgs = all_imgs + sub_files + sub_labels = list( np.zeros( len(sub_files) ) +indexer ) + + #print( len(sub_labels) ) + labels = labels + sub_labels + #dicts[sub_c] = indexer + #indexer +=1 + + ids = np.array(range(len(labels))) + random.shuffle(ids) + + shuffled_labels = np.array(labels)[ids] + shuffled_files = np.array(all_imgs)[ids] + categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] + ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 + while True: + for i in range(len(shuffled_files)): + row = 
shuffled_files[i] + #print(row) + ###img = cv2.imread(row, 0) + ###img= resize_image (img, height, width) + ###img = img.astype(np.uint16) + ###ret_x[batchcount, :,:,0] = img[:,:] + ###ret_x[batchcount, :,:,1] = img[:,:] + ###ret_x[batchcount, :,:,2] = img[:,:] + + img = cv2.imread(row) + img= resize_image (img, height, width) + img = img.astype(np.uint16) + ret_x[batchcount, :,:,:] = img[:,:,:] + + #print(int(shuffled_labels[i]) ) + #print( categories[int(shuffled_labels[i])] ) + ret_y[batchcount, :] = categories[ int( shuffled_labels[i] ) ][:] + + batchcount+=1 + + if batchcount>=batchsize: + ret_x = ret_x/255. + yield (ret_x, ret_y) + ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 + +def do_brightening(img_in_dir, factor): + im = Image.open(img_in_dir) + enhancer = ImageEnhance.Brightness(im) + out_img = enhancer.enhance(factor) + out_img = out_img.convert('RGB') + opencv_img = np.array(out_img) + opencv_img = opencv_img[:,:,::-1].copy() + return opencv_img + + +def bluring(img_in, kind): + if kind == 'gauss': + img_blur = cv2.GaussianBlur(img_in, (5, 5), 0) + elif kind == "median": + img_blur = cv2.medianBlur(img_in, 5) + elif kind == 'blur': + img_blur = cv2.blur(img_in, (5, 5)) + return img_blur + + +def elastic_transform(image, alpha, sigma, seedj, random_state=None): + """Elastic deformation of images as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + if random_state is None: + random_state = np.random.RandomState(seedj) + + shape = image.shape + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dz = np.zeros_like(dx) + + x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)) + + distored_image = map_coordinates(image, indices, order=1, mode='reflect') + return distored_image.reshape(image.shape) + + +def rotation_90(img): + img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2])) + img_rot[:, :, 0] = img[:, :, 0].T + img_rot[:, :, 1] = img[:, :, 1].T + img_rot[:, :, 2] = img[:, :, 2].T + return img_rot + + +def rotatedRectWithMaxArea(w, h, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + radians), computes the width and height of the largest possible + axis-aligned rectangle (maximal area) within the rotated rectangle. + """ + if w <= 0 or h <= 0: + return 0, 0 + + width_is_longer = w >= h + side_long, side_short = (w, h) if width_is_longer else (h, w) + + # since the solutions for angle, -angle and 180-angle are all the same, + # if suffices to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2. 
* sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5 * side_short + wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a * cos_a - sin_a * sin_a + wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a + + return wr, hr + + +def rotate_max_area(image, rotated, rotated_label, angle): + """ image: cv2 image matrix object + angle: in degree + """ + wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], + math.radians(angle)) + h, w, _ = rotated.shape + y1 = h // 2 - int(hr / 2) + y2 = y1 + int(hr) + x1 = w // 2 - int(wr / 2) + x2 = x1 + int(wr) + return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2] + +def rotate_max_area_single_image(image, rotated, angle): + """ image: cv2 image matrix object + angle: in degree + """ + wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], + math.radians(angle)) + h, w, _ = rotated.shape + y1 = h // 2 - int(hr / 2) + y2 = y1 + int(hr) + x1 = w // 2 - int(wr / 2) + x2 = x1 + int(wr) + return rotated[y1:y2, x1:x2] + +def rotation_not_90_func(img, label, thetha): + rotated = imutils.rotate(img, thetha) + rotated_label = imutils.rotate(label, thetha) + return rotate_max_area(img, rotated, rotated_label, thetha) + + +def rotation_not_90_func_single_image(img, thetha): + rotated = imutils.rotate(img, thetha) + return rotate_max_area_single_image(img, rotated, thetha) + + +def color_images(seg, n_classes): + ann_u = range(n_classes) + if len(np.shape(seg)) == 3: + seg = seg[:, :, 0] + + seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(float) + colors = sns.color_palette("hls", n_classes) + + for c in ann_u: + c = int(c) + segl = (seg == c) + seg_img[:, :, 0] += segl * (colors[c][0]) + seg_img[:, :, 1] += segl * (colors[c][1]) + seg_img[:, :, 2] += segl * (colors[c][2]) + return seg_img + + +def resize_image(seg_in, input_height, input_width): + return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) + + +def get_one_hot(seg, input_height, input_width, n_classes): + seg = seg[:, :, 0] + seg_f = np.zeros((input_height, input_width, n_classes)) + for j in range(n_classes): + seg_f[:, :, j] = (seg == j).astype(int) + return seg_f + + +def IoU(Yi, y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + + IoUs = [] + classes_true = np.unique(Yi) + for c in classes_true: + TP = np.sum((Yi == c) & (y_predi == c)) + FP = np.sum((Yi != c) & (y_predi == c)) + FN = np.sum((Yi == c) & (y_predi != c)) + IoU = TP / float(TP + FP + FN) + #print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU)) + IoUs.append(IoU) + mIoU = np.mean(IoUs) + #print("_________________") + #print("Mean IoU: {:4.3f}".format(mIoU)) + return mIoU + +def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False): + all_labels_files = os.listdir(classes_file_dir) + ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 + while True: + for i in all_labels_files: + file_name = os.path.splitext(i)[0] + img = cv2.imread(os.path.join(modal_dir,file_name+'.png')) + + label_class = int( np.load(os.path.join(classes_file_dir,i)) ) + + 
+def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False):
+    all_labels_files = os.listdir(classes_file_dir)
+    ret_x = np.zeros((batchsize, height, width, 3))
+    ret_y = np.zeros((batchsize, n_classes)).astype(np.int16)
+    batchcount = 0
+    while True:
+        for i in all_labels_files:
+            file_name = os.path.splitext(i)[0]
+            img = cv2.imread(os.path.join(modal_dir, file_name + '.png'))
+
+            label_class = int(np.load(os.path.join(classes_file_dir, i)))
+
+            # channel-wise input scaling
+            ret_x[batchcount, :, :, 0] = img[:, :, 0] / 3.0
+            ret_x[batchcount, :, :, 2] = img[:, :, 2] / 3.0
+            ret_x[batchcount, :, :, 1] = img[:, :, 1] / 5.0
+
+            ret_y[batchcount, :] = label_class
+            batchcount += 1
+            if batchcount >= batchsize:
+                yield (ret_x, ret_y)
+                ret_x = np.zeros((batchsize, height, width, 3))
+                ret_y = np.zeros((batchsize, n_classes)).astype(np.int16)
+                batchcount = 0
+
+            if augmentation:
+                for thetha_i in thetha:
+                    img_rot = rotation_not_90_func_single_image(img, thetha_i)
+                    img_rot = resize_image(img_rot, height, width)
+
+                    ret_x[batchcount, :, :, 0] = img_rot[:, :, 0] / 3.0
+                    ret_x[batchcount, :, :, 2] = img_rot[:, :, 2] / 3.0
+                    ret_x[batchcount, :, :, 1] = img_rot[:, :, 1] / 5.0
+
+                    ret_y[batchcount, :] = label_class
+                    batchcount += 1
+                    if batchcount >= batchsize:
+                        yield (ret_x, ret_y)
+                        ret_x = np.zeros((batchsize, height, width, 3))
+                        ret_y = np.zeros((batchsize, n_classes)).astype(np.int16)
+                        batchcount = 0
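+
+
+# Hedged usage sketch: a generator like the one above is typically handed
+# straight to Keras. 'model' and all paths/values below are placeholders,
+# not names defined in this module, so the call is left commented out.
+# model.fit(
+#     generate_arrays_from_folder_reading_order('labels/', 'inputs/',
+#                                               8, 448, 448, 1, [5, -5]),
+#     steps_per_epoch=100, epochs=10)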
+def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
+    c = 0
+    n = [f for f in os.listdir(img_folder) if not f.startswith('.')]  # list of training images, hidden files excluded
+    random.shuffle(n)
+    while True:
+        img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
+        mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')
+
+        for i in range(c, c + batch_size):  # e.g. 0..batch_size-1 on the first pass
+            try:
+                filename = os.path.splitext(n[i])[0]
+
+                train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
+                train_img = cv2.resize(train_img, (input_width, input_height),
+                                       interpolation=cv2.INTER_NEAREST)  # read an image and resize
+
+                img[i - c] = train_img  # add to array - img[0], img[1], and so on
+                if task == "segmentation" or task == "binarization":
+                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
+                    train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
+                                             n_classes)
+                elif task == "enhancement":
+                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png') / 255.
+                    train_mask = resize_image(train_mask, input_height, input_width)
+
+                mask[i - c] = train_mask
+            except Exception:
+                # on a failed read, fall back to a blank sample
+                img[i - c] = np.ones((input_height, input_width, 3)).astype('float')
+                mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float')
+
+        c += batch_size
+        if c + batch_size >= len(n):  # compare against the filtered list, not os.listdir()
+            c = 0
+            random.shuffle(n)
+        yield img, mask
+
+
+def otsu_copy(img):
+    img_r = np.zeros(img.shape)
+    # Otsu-binarize the first channel and replicate it to all three channels
+    _, threshold1 = cv2.threshold(img[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+    img_r[:, :, 0] = threshold1
+    img_r[:, :, 1] = threshold1
+    img_r[:, :, 2] = threshold1
+    return img_r
+
+
+def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
+    if img.shape[0] < height or img.shape[1] < width:
+        img, label = do_padding(img, label, height, width)
+
+    img_h = img.shape[0]
+    img_w = img.shape[1]
+
+    # number of patches per axis, rounded up
+    nxf = img_w / float(width)
+    nyf = img_h / float(height)
+
+    if nxf > int(nxf):
+        nxf = int(nxf) + 1
+    if nyf > int(nyf):
+        nyf = int(nyf) + 1
+
+    nxf = int(nxf)
+    nyf = int(nyf)
+
+    for i in range(nxf):
+        for j in range(nyf):
+            index_x_d = i * width
+            index_x_u = (i + 1) * width
+
+            index_y_d = j * height
+            index_y_u = (j + 1) * height
+
+            # shift the last row/column back inside the image borders
+            if index_x_u > img_w:
+                index_x_u = img_w
+                index_x_d = img_w - width
+            if index_y_u > img_h:
+                index_y_u = img_h
+                index_y_d = img_h - height
+
+            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
+            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
+
+            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
+            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
+            indexer += 1
+
+    return indexer
+
+
+def do_padding_white(img):
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2])) + 255
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(float)
+
+
+def do_degrading(img, scale):
+    img_org_h = img.shape[0]
+    img_org_w = img.shape[1]
+
+    # downscale and scale back up to simulate a degraded scan
+    img_res = resize_image(img, int(img_org_h * scale), int(img_org_w * scale))
+
+    return resize_image(img_res, img_org_h, img_org_w)
+
+
+def do_padding_black(img):
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2]))
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(float)
+
+
+def do_padding_label(img):
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2]))
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(np.int16)
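+
+
+# Worked example of the tiling arithmetic in get_patches (the numbers are made
+# up): a 1000x600 image cut into 448x448 patches gives ceil(1000/448) = 3
+# columns and ceil(600/448) = 2 rows; the last column/row is shifted back to
+# end at the image border, so those patches overlap their neighbours.
+def _demo_patch_grid(img_w=1000, img_h=600, width=448, height=448):
+    nxf = img_w // width + (1 if img_w % width else 0)
+    nyf = img_h // height + (1 if img_h % height else 0)
+    return nxf, nyf  # -> (3, 2)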
+def do_padding(img, label, height, width):
+    height_new = img.shape[0]
+    width_new = img.shape[1]
+
+    h_start = 0
+    w_start = 0
+
+    if img.shape[0] < height:
+        h_start = int(abs(height - img.shape[0]) / 2.)
+        height_new = height
+
+    if img.shape[1] < width:
+        w_start = int(abs(width - img.shape[1]) / 2.)
+        width_new = width
+
+    # centre the image on a white canvas and the label on a zero (background) canvas
+    img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255
+    label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float)
+
+    img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
+    label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :])
+
+    return img_new, label_new
+
+
+def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
+    # n_patches is currently unused; the parameter is kept so existing call sites do not break
+    if img.shape[0] < height or img.shape[1] < width:
+        img, label = do_padding(img, label, height, width)
+
+    img_h = img.shape[0]
+    img_w = img.shape[1]
+
+    height_scale = int(height * scaler)
+    width_scale = int(width * scaler)
+
+    nxf = img_w / float(width_scale)
+    nyf = img_h / float(height_scale)
+
+    if nxf > int(nxf):
+        nxf = int(nxf) + 1
+    if nyf > int(nyf):
+        nyf = int(nyf) + 1
+
+    nxf = int(nxf)
+    nyf = int(nyf)
+
+    for i in range(nxf):
+        for j in range(nyf):
+            index_x_d = i * width_scale
+            index_x_u = (i + 1) * width_scale
+
+            index_y_d = j * height_scale
+            index_y_u = (j + 1) * height_scale
+
+            if index_x_u > img_w:
+                index_x_u = img_w
+                index_x_d = img_w - width_scale
+            if index_y_u > img_h:
+                index_y_u = img_h
+                index_y_d = img_h - height_scale
+
+            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
+            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
+
+            img_patch = resize_image(img_patch, height, width)
+            label_patch = resize_image(label_patch, height, width)
+
+            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
+            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
+            indexer += 1
+
+    return indexer
+
+
+def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
+    img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler))
+    label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler))
+
+    if img.shape[0] < height or img.shape[1] < width:
+        img, label = do_padding(img, label, height, width)
+
+    img_h = img.shape[0]
+    img_w = img.shape[1]
+
+    # the image itself has been rescaled above, so the patches keep the requested size
+    height_scale = height
+    width_scale = width
+
+    nxf = img_w / float(width_scale)
+    nyf = img_h / float(height_scale)
+
+    if nxf > int(nxf):
+        nxf = int(nxf) + 1
+    if nyf > int(nyf):
+        nyf = int(nyf) + 1
+
+    nxf = int(nxf)
+    nyf = int(nyf)
+
+    for i in range(nxf):
+        for j in range(nyf):
+            index_x_d = i * width_scale
+            index_x_u = (i + 1) * width_scale
+
+            index_y_d = j * height_scale
+            index_y_u = (j + 1) * height_scale
+
+            if index_x_u > img_w:
+                index_x_u = img_w
+                index_x_d = img_w - width_scale
+            if index_y_u > img_h:
+                index_y_u = img_h
+                index_y_d = img_h - height_scale
+
+            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
+            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
+
+            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
+            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
+            indexer += 1
+
+    return indexer
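+
+
+# Minimal sketch of do_padding (shapes made up): a 300x200 image destined for
+# 448x448 patches is centred on a white 448x448 canvas, while its label is
+# centred on a zero canvas of the same size.
+def _demo_do_padding():
+    img = np.ones((300, 200, 3)) * 127
+    label = np.zeros((300, 200, 3))
+    img_p, label_p = do_padding(img, label, 448, 448)
+    return img_p.shape, label_p.shape  # -> ((448, 448, 3), (448, 448, 3))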
+def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
+                    dir_flow_train_labels, input_height, input_width, blur_k, blur_aug,
+                    padding_white, padding_black, flip_aug, binarization, adding_rgb_background,
+                    adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, shifting,
+                    degrading, brightening, scales, degrade_scales, brightness, flip_index,
+                    shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization,
+                    rotation, rotation_not_90, thetha, scaling_flip, task,
+                    augmentation=False, patches=False, dir_img_bin=None,
+                    number_of_backgrounds_per_image=None, list_all_possible_background_images=None,
+                    dir_rgb_backgrounds=None, dir_rgb_foregrounds=None,
+                    list_all_possible_foreground_rgbs=None):
+
+    indexer = 0
+    for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
+        img_name = os.path.splitext(im)[0]
+        if task == "segmentation" or task == "binarization":
+            dir_of_label_file = os.path.join(dir_seg, img_name + '.png')
+        elif task == "enhancement":
+            dir_of_label_file = os.path.join(dir_seg, im)
+
+        if not patches:
+            cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
+            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+            indexer += 1
+
+            if augmentation:
+                if flip_aug:
+                    for f_i in flip_index:
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, input_width))
+
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width))
+                        indexer += 1
+
+                if blur_aug:
+                    for blur_i in blur_k:
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width)))
+
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                        indexer += 1
+
+                if brightening:
+                    for factor in brightness:
+                        try:
+                            cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                        (resize_image(do_brightening(dir_img + '/' + im, factor), input_height, input_width)))
+
+                            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                        resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                            indexer += 1
+                        except Exception:
+                            pass
+
+                if binarization:
+                    if dir_img_bin:
+                        img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    resize_image(img_bin_corr, input_height, input_width))
+                    else:
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))
+
+                    cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                    indexer += 1
+
+                if degrading:
+                    for degrade_scale_ind in degrade_scales:
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    (resize_image(do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind), input_height, input_width)))
+
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                        indexer += 1
+
+                if rotation_not_90:
+                    for thetha_i in thetha:
+                        img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im),
+                                                                                  cv2.imread(dir_of_label_file), thetha_i)
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_max_rotated, input_height, input_width))
+
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_max_rotated, input_height, input_width))
+                        indexer += 1
+                if channels_shuffling:
+                    for shuffle_index in shuffle_indexes:
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
+                                    (resize_image(return_shuffled_channels(cv2.imread(dir_img + '/' + im), shuffle_index), input_height, input_width)))
+
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                        indexer += 1
+
+                if scaling:
+                    for sc_ind in scales:
+                        img_scaled, label_scaled = scale_image_for_no_patch(cv2.imread(dir_img + '/' + im),
+                                                                            cv2.imread(dir_of_label_file), sc_ind)
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_scaled, input_height, input_width))
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width))
+                        indexer += 1
+
+                if shifting:
+                    shift_types = ['xpos', 'xmin', 'ypos', 'ymin', 'xypos', 'xymin']
+                    for st_ind in shift_types:
+                        img_shifted, label_shifted = shift_image_and_label(cv2.imread(dir_img + '/' + im),
+                                                                           cv2.imread(dir_of_label_file), st_ind)
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_shifted, input_height, input_width))
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_shifted, input_height, input_width))
+                        indexer += 1
+
+                if adding_rgb_background:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width))
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                        indexer += 1
+
+                if adding_rgb_foreground:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
+
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        foreground_rgb_chosen = np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
+
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
+
+                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width))
+                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                    resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                        indexer += 1
+
+                if add_red_textlines:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/' + im), img_bin_corr)
+
+                    cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_red_context, input_height, input_width))
+                    cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
+                                resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
+                    indexer += 1
+        if patches:
+            indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                  cv2.imread(dir_img + '/' + im), cv2.imread(dir_of_label_file),
+                                  input_height, input_width, indexer=indexer)
+
+            if augmentation:
+                if rotation:
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          rotation_90(cv2.imread(dir_img + '/' + im)),
+                                          rotation_90(cv2.imread(dir_of_label_file)),
+                                          input_height, input_width, indexer=indexer)
+
+                if rotation_not_90:
+                    for thetha_i in thetha:
+                        img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im),
+                                                                                  cv2.imread(dir_of_label_file), thetha_i)
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              img_max_rotated,
+                                              label_max_rotated,
+                                              input_height, input_width, indexer=indexer)
+
+                if channels_shuffling:
+                    for shuffle_index in shuffle_indexes:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              return_shuffled_channels(cv2.imread(dir_img + '/' + im), shuffle_index),
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+
+                if adding_rgb_background:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
+
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              img_with_overlayed_background,
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+
+                if adding_rgb_foreground:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
+
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        foreground_rgb_chosen = np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
+
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
+
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              img_with_overlayed_background,
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+
+                if add_red_textlines:
+                    img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+                    img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/' + im), img_bin_corr)
+
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          img_red_context,
+                                          cv2.imread(dir_of_label_file),
+                                          input_height, input_width, indexer=indexer)
+
+                if flip_aug:
+                    for f_i in flip_index:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
+                                              cv2.flip(cv2.imread(dir_of_label_file), f_i),
+                                              input_height, input_width, indexer=indexer)
+
+                if blur_aug:
+                    for blur_i in blur_k:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              bluring(cv2.imread(dir_img + '/' + im), blur_i),
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+                if padding_black:
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          do_padding_black(cv2.imread(dir_img + '/' + im)),
+                                          do_padding_label(cv2.imread(dir_of_label_file)),
+                                          input_height, input_width, indexer=indexer)
+
+                if padding_white:
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          do_padding_white(cv2.imread(dir_img + '/' + im)),
+                                          do_padding_label(cv2.imread(dir_of_label_file)),
+                                          input_height, input_width, indexer=indexer)
+
+                if brightening:
+                    for factor in brightness:
+                        try:
+                            indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                                  do_brightening(dir_img + '/' + im, factor),
+                                                  cv2.imread(dir_of_label_file),
+                                                  input_height, input_width, indexer=indexer)
+                        except Exception:
+                            pass
+
+                if scaling:
+                    for sc_ind in scales:
+                        indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                            cv2.imread(dir_img + '/' + im),
+                                                            cv2.imread(dir_of_label_file),
+                                                            input_height, input_width, indexer=indexer, scaler=sc_ind)
+
+                if degrading:
+                    for degrade_scale_ind in degrade_scales:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind),
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+
+                if binarization:
+                    if dir_img_bin:
+                        img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name + '.png')
+
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              img_bin_corr,
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+                    else:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              otsu_copy(cv2.imread(dir_img + '/' + im)),
+                                              cv2.imread(dir_of_label_file),
+                                              input_height, input_width, indexer=indexer)
+
+                if scaling_brightness:
+                    for sc_ind in scales:
+                        for factor in brightness:
+                            try:
+                                indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                                    do_brightening(dir_img + '/' + im, factor),
+                                                                    cv2.imread(dir_of_label_file),
+                                                                    input_height, input_width, indexer=indexer, scaler=sc_ind)
+                            except Exception:
+                                pass
+
+                if scaling_bluring:
+                    for sc_ind in scales:
+                        for blur_i in blur_k:
+                            indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                                bluring(cv2.imread(dir_img + '/' + im), blur_i),
+                                                                cv2.imread(dir_of_label_file),
+                                                                input_height, input_width, indexer=indexer, scaler=sc_ind)
+
+                if scaling_binarization:
+                    for sc_ind in scales:
+                        indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                            otsu_copy(cv2.imread(dir_img + '/' + im)),
+                                                            cv2.imread(dir_of_label_file),
+                                                            input_height, input_width, indexer=indexer, scaler=sc_ind)
+
+                if scaling_flip:
+                    for sc_ind in scales:
+                        for f_i in flip_index:
+                            indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                                cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
+                                                                cv2.flip(cv2.imread(dir_of_label_file), f_i),
+                                                                input_height, input_width, indexer=indexer, scaler=sc_ind)
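+
+
+# Hedged usage sketch for provide_patches (all paths, lists and flag values
+# below are placeholders, not defaults or project settings), left commented out
+# since imgs_list_train/segs_list_train are not defined here. A typical call
+# enables only a couple of augmentations:
+#
+# provide_patches(imgs_list_train, segs_list_train, './images', './labels',
+#                 './flow/images', './flow/labels', 448, 448,
+#                 blur_k=['gauss'], blur_aug=True, padding_white=False, padding_black=False,
+#                 flip_aug=True, binarization=False, adding_rgb_background=False,
+#                 adding_rgb_foreground=False, add_red_textlines=False, channels_shuffling=False,
+#                 scaling=False, shifting=False, degrading=False, brightening=False,
+#                 scales=[], degrade_scales=[], brightness=[], flip_index=[0, 1],
+#                 shuffle_indexes=[], scaling_bluring=False, scaling_brightness=False,
+#                 scaling_binarization=False, rotation=False, rotation_not_90=False,
+#                 thetha=[], scaling_flip=False, task='segmentation',
+#                 augmentation=True, patches=True)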