mirror of
				https://github.com/qurator-spk/eynollah.git
				synced 2025-10-22 06:14:20 +02:00 
			
		
		
		
	Merge pull request #193 from qurator-spk/training-installation
Training installation
This commit is contained in:
		
						commit
						3bd3faef68
					
				
					 12 changed files with 120 additions and 58 deletions
				
			
		|  | @ -13,7 +13,7 @@ The following three tasks can all be accomplished using the code in the | ||||||
| * train a model | * train a model | ||||||
| * inference with the trained model | * inference with the trained model | ||||||
| 
 | 
 | ||||||
| ## Training , evaluation and output  | ## Training, evaluation and output  | ||||||
| 
 | 
 | ||||||
| The train and evaluation folders should contain subfolders of `images` and `labels`. | The train and evaluation folders should contain subfolders of `images` and `labels`. | ||||||
| 
 | 
 | ||||||
|  | @ -22,11 +22,13 @@ The output folder should be an empty folder where the output model will be writt | ||||||
| ## Generate training dataset | ## Generate training dataset | ||||||
| 
 | 
 | ||||||
| The script `generate_gt_for_training.py` is used for generating training datasets. As the results of the following | The script `generate_gt_for_training.py` is used for generating training datasets. As the results of the following | ||||||
| command demonstrates, the dataset generator provides three different commands: | command demonstrates, the dataset generator provides several subcommands: | ||||||
| 
 | 
 | ||||||
| `python generate_gt_for_training.py --help` | ```sh | ||||||
|  | eynollah-training generate-gt --help | ||||||
|  | ``` | ||||||
| 
 | 
 | ||||||
| These three commands are: | The three most important subcommands are: | ||||||
| 
 | 
 | ||||||
| * image-enhancement | * image-enhancement | ||||||
| * machine-based-reading-order | * machine-based-reading-order | ||||||
|  | @ -38,7 +40,7 @@ Generating a training dataset for image enhancement is quite straightforward. Al | ||||||
| high-resolution images. The training dataset can then be generated using the following command: | high-resolution images. The training dataset can then be generated using the following command: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python generate_gt_for_training.py image-enhancement \ | eynollah-training image-enhancement \ | ||||||
|   -dis "dir of high resolution images" \ |   -dis "dir of high resolution images" \ | ||||||
|   -dois "dir where degraded images will be written" \ |   -dois "dir where degraded images will be written" \ | ||||||
|   -dols "dir where the corresponding high resolution image will be written as label" \ |   -dols "dir where the corresponding high resolution image will be written as label" \ | ||||||
|  | @ -69,7 +71,7 @@ to filter out regions smaller than this minimum size. This minimum size is defin | ||||||
| to the image area, with a default value of zero. To run the dataset generator, use the following command: | to the image area, with a default value of zero. To run the dataset generator, use the following command: | ||||||
| 
 | 
 | ||||||
| ```shell | ```shell | ||||||
| python generate_gt_for_training.py machine-based-reading-order \ | eynollah-training generate-gt machine-based-reading-order \ | ||||||
|   -dx "dir of GT xml files" \ |   -dx "dir of GT xml files" \ | ||||||
|   -domi "dir where output images will be written" \ |   -domi "dir where output images will be written" \ | ||||||
| "" -docl "dir where the labels will be written" \ | "" -docl "dir where the labels will be written" \ | ||||||
|  | @ -144,7 +146,7 @@ region" are also present in the label. However, other regions like "noise region | ||||||
| included in the label PNG file, even if they have information in the page XML files, as we chose not to include them. | included in the label PNG file, even if they have information in the page XML files, as we chose not to include them. | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python generate_gt_for_training.py pagexml2label \ | eynollah-training generate-gt pagexml2label \ | ||||||
|   -dx "dir of GT xml files" \ |   -dx "dir of GT xml files" \ | ||||||
|   -do "dir where output label png files will be written" \ |   -do "dir where output label png files will be written" \ | ||||||
|   -cfg "custom config json file" \ |   -cfg "custom config json file" \ | ||||||
|  | @ -198,7 +200,7 @@ provided to ensure that they are cropped in sync with the labels. This ensures t | ||||||
| required for training are obtained. The command should resemble the following: | required for training are obtained. The command should resemble the following: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python generate_gt_for_training.py pagexml2label \ | eynollah-training generate-gt pagexml2label \ | ||||||
|   -dx "dir of GT xml files" \ |   -dx "dir of GT xml files" \ | ||||||
|   -do "dir where output label png files will be written" \ |   -do "dir where output label png files will be written" \ | ||||||
|   -cfg "custom config json file" \ |   -cfg "custom config json file" \ | ||||||
|  | @ -261,7 +263,7 @@ And the "dir_eval" the same structure as train directory: | ||||||
| The classification model can be trained using the following command line: | The classification model can be trained using the following command line: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python train.py with config_classification.json | eynollah-training train with config_classification.json | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| As evident in the example JSON file above, for classification, we utilize a "f1_threshold_classification" parameter. | As evident in the example JSON file above, for classification, we utilize a "f1_threshold_classification" parameter. | ||||||
|  | @ -395,7 +397,9 @@ And the "dir_eval" the same structure as train directory: | ||||||
| After configuring the JSON file for segmentation or enhancement, training can be initiated by running the following | After configuring the JSON file for segmentation or enhancement, training can be initiated by running the following | ||||||
| command, similar to the process for classification and reading order: | command, similar to the process for classification and reading order: | ||||||
| 
 | 
 | ||||||
| `python train.py with config_classification.json` | ``` | ||||||
|  | eynollah-training train with config_classification.json` | ||||||
|  | ``` | ||||||
| 
 | 
 | ||||||
| #### Binarization | #### Binarization | ||||||
| 
 | 
 | ||||||
|  | @ -679,7 +683,7 @@ For conducting inference with a trained model, you simply need to execute the fo | ||||||
| directory of the model and the image on which to perform inference: | directory of the model and the image on which to perform inference: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python inference.py -m "model dir" -i "image" | eynollah-training inference -m "model dir" -i "image" | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| This will straightforwardly return the class of the image. | This will straightforwardly return the class of the image. | ||||||
|  | @ -691,7 +695,7 @@ without the reading order. We simply need to provide the model directory, the XM | ||||||
| new XML file with the added reading order will be written to the output directory with the same name. We need to run: | new XML file with the added reading order will be written to the output directory with the same name. We need to run: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python inference.py \ | eynollah-training inference \ | ||||||
|   -m "model dir" \ |   -m "model dir" \ | ||||||
|   -xml "page xml file" \ |   -xml "page xml file" \ | ||||||
|   -o "output dir to write new xml with reading order" |   -o "output dir to write new xml with reading order" | ||||||
|  | @ -702,7 +706,7 @@ python inference.py \ | ||||||
| For conducting inference with a trained model for segmentation and enhancement you need to run the following command line: | For conducting inference with a trained model for segmentation and enhancement you need to run the following command line: | ||||||
| 
 | 
 | ||||||
| ```sh | ```sh | ||||||
| python inference.py \ | eynollah-training inference \ | ||||||
|   -m "model dir" \ |   -m "model dir" \ | ||||||
|   -i "image" \ |   -i "image" \ | ||||||
|   -p \ |   -p \ | ||||||
|  |  | ||||||
|  | @ -31,6 +31,7 @@ classifiers = [ | ||||||
| 
 | 
 | ||||||
| [project.scripts] | [project.scripts] | ||||||
| eynollah = "eynollah.cli:main" | eynollah = "eynollah.cli:main" | ||||||
|  | eynollah-training = "eynollah.training.cli:main" | ||||||
| ocrd-eynollah-segment = "eynollah.ocrd_cli:main" | ocrd-eynollah-segment = "eynollah.ocrd_cli:main" | ||||||
| ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main" | ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,20 +1,15 @@ | ||||||
| import os | import click | ||||||
| import sys |  | ||||||
| import tensorflow as tf | import tensorflow as tf | ||||||
| import warnings | 
 | ||||||
| from tensorflow.keras.optimizers import * | from .models import resnet50_unet | ||||||
| from sacred import Experiment |  | ||||||
| from models import * |  | ||||||
| from utils import * |  | ||||||
| from metrics import * |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def configuration(): | def configuration(): | ||||||
|     gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) |     gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) | ||||||
|     session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) |     session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) | ||||||
| 
 | 
 | ||||||
| 
 | @click.command() | ||||||
| if __name__ == '__main__': | def build_model_load_pretrained_weights_and_save(): | ||||||
|     n_classes = 2 |     n_classes = 2 | ||||||
|     input_height = 224 |     input_height = 224 | ||||||
|     input_width = 448 |     input_width = 448 | ||||||
							
								
								
									
										26
									
								
								src/eynollah/training/cli.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								src/eynollah/training/cli.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,26 @@ | ||||||
|  | import os | ||||||
|  | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  | ||||||
|  | 
 | ||||||
|  | import click | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save | ||||||
|  | from .generate_gt_for_training import main as generate_gt_cli | ||||||
|  | from .inference import main as inference_cli | ||||||
|  | from .train import ex | ||||||
|  | 
 | ||||||
|  | @click.command(context_settings=dict( | ||||||
|  |         ignore_unknown_options=True, | ||||||
|  | )) | ||||||
|  | @click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED) | ||||||
|  | def train_cli(sacred_args): | ||||||
|  |     ex.run_commandline([sys.argv[0]] + list(sacred_args)) | ||||||
|  | 
 | ||||||
|  | @click.group('training') | ||||||
|  | def main(): | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
|  | main.add_command(build_model_load_pretrained_weights_and_save) | ||||||
|  | main.add_command(generate_gt_cli, 'generate-gt') | ||||||
|  | main.add_command(inference_cli, 'inference') | ||||||
|  | main.add_command(train_cli, 'train') | ||||||
|  | @ -1,9 +1,28 @@ | ||||||
| import click | import click | ||||||
| import json | import json | ||||||
| from gt_gen_utils import * | import os | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from PIL import Image, ImageDraw, ImageFont | from PIL import Image, ImageDraw, ImageFont | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | from eynollah.training.gt_gen_utils import ( | ||||||
|  |     filter_contours_area_of_image, | ||||||
|  |     find_format_of_given_filename_in_dir, | ||||||
|  |     find_new_features_of_contours, | ||||||
|  |     fit_text_single_line, | ||||||
|  |     get_content_of_dir, | ||||||
|  |     get_images_of_ground_truth, | ||||||
|  |     get_layout_contours_for_visualization, | ||||||
|  |     get_textline_contours_and_ocr_text, | ||||||
|  |     get_textline_contours_for_visualization, | ||||||
|  |     overlay_layout_on_image, | ||||||
|  |     read_xml, | ||||||
|  |     resize_image, | ||||||
|  |     visualize_image_from_contours, | ||||||
|  |     visualize_image_from_contours_layout | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| @click.group() | @click.group() | ||||||
| def main(): | def main(): | ||||||
|  | @ -562,6 +581,3 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): | ||||||
|                     # Draw the text |                     # Draw the text | ||||||
|                     draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) |                     draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) | ||||||
|         image_text.save(os.path.join(dir_out, f_name+'.png')) |         image_text.save(os.path.join(dir_out, f_name+'.png')) | ||||||
|      |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     main() |  | ||||||
|  | @ -1,5 +1,3 @@ | ||||||
| import click |  | ||||||
| import sys |  | ||||||
| import os | import os | ||||||
| import numpy as np | import numpy as np | ||||||
| import warnings | import warnings | ||||||
|  | @ -8,8 +6,7 @@ from tqdm import tqdm | ||||||
| import cv2 | import cv2 | ||||||
| from shapely import geometry | from shapely import geometry | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import matplotlib.pyplot as plt | from PIL import ImageFont | ||||||
| from PIL import Image, ImageDraw, ImageFont |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| KERNEL = np.ones((5, 5), np.uint8) | KERNEL = np.ones((5, 5), np.uint8) | ||||||
|  | @ -1,23 +1,29 @@ | ||||||
| import sys | import sys | ||||||
| import os | import os | ||||||
| import numpy as np |  | ||||||
| import warnings | import warnings | ||||||
|  | import json | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
| import cv2 | import cv2 | ||||||
| import seaborn as sns |  | ||||||
| from tensorflow.keras.models import load_model | from tensorflow.keras.models import load_model | ||||||
| import tensorflow as tf | import tensorflow as tf | ||||||
| from tensorflow.keras import backend as K | from tensorflow.keras import backend as K | ||||||
| from tensorflow.keras import layers |  | ||||||
| import tensorflow.keras.losses |  | ||||||
| from tensorflow.keras.layers import * | from tensorflow.keras.layers import * | ||||||
| from models import * |  | ||||||
| from gt_gen_utils import * |  | ||||||
| import click | import click | ||||||
| import json |  | ||||||
| from tensorflow.python.keras import backend as tensorflow_backend | from tensorflow.python.keras import backend as tensorflow_backend | ||||||
| import xml.etree.ElementTree as ET | import xml.etree.ElementTree as ET | ||||||
| import matplotlib.pyplot as plt |  | ||||||
| 
 | 
 | ||||||
|  | from .gt_gen_utils import ( | ||||||
|  |     filter_contours_area_of_image, | ||||||
|  |     find_new_features_of_contours, | ||||||
|  |     read_xml, | ||||||
|  |     resize_image, | ||||||
|  |     update_list_and_return_first_with_length_bigger_than_one | ||||||
|  | ) | ||||||
|  | from .models import ( | ||||||
|  |     PatchEncoder, | ||||||
|  |     Patches | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| with warnings.catch_warnings(): | with warnings.catch_warnings(): | ||||||
|     warnings.simplefilter("ignore") |     warnings.simplefilter("ignore") | ||||||
|  | @ -55,11 +61,9 @@ class sbb_predict: | ||||||
|             seg=seg[:,:,0] |             seg=seg[:,:,0] | ||||||
|              |              | ||||||
|         seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8) |         seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8) | ||||||
|         colors=sns.color_palette("hls", self.n_classes) |  | ||||||
|          |          | ||||||
|         for c in ann_u: |         for c in ann_u: | ||||||
|             c=int(c) |             c=int(c) | ||||||
|             segl=(seg==c) |  | ||||||
|             seg_img[:,:,0][seg==c]=c |             seg_img[:,:,0][seg==c]=c | ||||||
|             seg_img[:,:,1][seg==c]=c |             seg_img[:,:,1][seg==c]=c | ||||||
|             seg_img[:,:,2][seg==c]=c |             seg_img[:,:,2][seg==c]=c | ||||||
|  | @ -674,9 +678,3 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil | ||||||
|     x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) |     x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) | ||||||
|     x.run() |     x.run() | ||||||
| 
 | 
 | ||||||
| if __name__=="__main__": |  | ||||||
|     main() |  | ||||||
| 
 |  | ||||||
|      |  | ||||||
|      |  | ||||||
|      |  | ||||||
|  | @ -1,20 +1,45 @@ | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
|  | import json | ||||||
|  | 
 | ||||||
|  | import click | ||||||
|  | 
 | ||||||
|  | from eynollah.training.metrics import ( | ||||||
|  |     soft_dice_loss, | ||||||
|  |     weighted_categorical_crossentropy | ||||||
|  | ) | ||||||
|  | from eynollah.training.models import ( | ||||||
|  |     PatchEncoder, | ||||||
|  |     Patches, | ||||||
|  |     machine_based_reading_order_model, | ||||||
|  |     resnet50_classifier, | ||||||
|  |     resnet50_unet, | ||||||
|  |     vit_resnet50_unet, | ||||||
|  |     vit_resnet50_unet_transformer_before_cnn | ||||||
|  | ) | ||||||
|  | from eynollah.training.utils import ( | ||||||
|  |     data_gen, | ||||||
|  |     generate_arrays_from_folder_reading_order, | ||||||
|  |     generate_data_from_folder_evaluation, | ||||||
|  |     generate_data_from_folder_training, | ||||||
|  |     get_one_hot, | ||||||
|  |     provide_patches, | ||||||
|  |     return_number_of_total_training_data | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  | ||||||
| import tensorflow as tf | import tensorflow as tf | ||||||
| from tensorflow.compat.v1.keras.backend import set_session | from tensorflow.compat.v1.keras.backend import set_session | ||||||
| import warnings | from tensorflow.keras.optimizers import SGD, Adam | ||||||
| from tensorflow.keras.optimizers import * |  | ||||||
| from sacred import Experiment | from sacred import Experiment | ||||||
| from models import * |  | ||||||
| from utils import * |  | ||||||
| from metrics import * |  | ||||||
| from tensorflow.keras.models import load_model | from tensorflow.keras.models import load_model | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| import json |  | ||||||
| from sklearn.metrics import f1_score | from sklearn.metrics import f1_score | ||||||
| from tensorflow.keras.callbacks import Callback | from tensorflow.keras.callbacks import Callback | ||||||
| 
 | 
 | ||||||
|  | import numpy as np | ||||||
|  | import cv2 | ||||||
|  | 
 | ||||||
| class SaveWeightsAfterSteps(Callback): | class SaveWeightsAfterSteps(Callback): | ||||||
|     def __init__(self, save_interval, save_path, _config): |     def __init__(self, save_interval, save_path, _config): | ||||||
|         super(SaveWeightsAfterSteps, self).__init__() |         super(SaveWeightsAfterSteps, self).__init__() | ||||||
|  | @ -45,8 +70,8 @@ def configuration(): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_dirs_or_files(input_data): | def get_dirs_or_files(input_data): | ||||||
|  |     image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') | ||||||
|     if os.path.isdir(input_data): |     if os.path.isdir(input_data): | ||||||
|         image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') |  | ||||||
|         # Check if training dir exists |         # Check if training dir exists | ||||||
|         assert os.path.isdir(image_input), "{} is not a directory".format(image_input) |         assert os.path.isdir(image_input), "{} is not a directory".format(image_input) | ||||||
|         assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) |         assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) | ||||||
|  | @ -121,7 +146,6 @@ def config_params(): | ||||||
|     dir_rgb_backgrounds = None |     dir_rgb_backgrounds = None | ||||||
|     dir_rgb_foregrounds = None |     dir_rgb_foregrounds = None | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| @ex.automain | @ex.automain | ||||||
| def run(_config, n_classes, n_epochs, input_height, | def run(_config, n_classes, n_epochs, input_height, | ||||||
|         input_width, weight_decay, weighted_loss, |         input_width, weight_decay, weighted_loss, | ||||||
|  | @ -423,7 +447,7 @@ def run(_config, n_classes, n_epochs, input_height, | ||||||
| 
 | 
 | ||||||
|         #f1score_tot = [0] |         #f1score_tot = [0] | ||||||
|         indexer_start = 0 |         indexer_start = 0 | ||||||
|         opt = SGD(learning_rate=0.01, momentum=0.9) |         # opt = SGD(learning_rate=0.01, momentum=0.9) | ||||||
|         opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) |         opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) | ||||||
|         model.compile(loss="binary_crossentropy", |         model.compile(loss="binary_crossentropy", | ||||||
|                             optimizer = opt_adam,metrics=['accuracy']) |                             optimizer = opt_adam,metrics=['accuracy']) | ||||||
|  | @ -1,13 +1,14 @@ | ||||||
| import os | import os | ||||||
|  | import math | ||||||
|  | import random | ||||||
|  | 
 | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
| import seaborn as sns | import seaborn as sns | ||||||
| from scipy.ndimage.interpolation import map_coordinates | from scipy.ndimage.interpolation import map_coordinates | ||||||
| from scipy.ndimage.filters import gaussian_filter | from scipy.ndimage.filters import gaussian_filter | ||||||
| import random |  | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| import imutils | import imutils | ||||||
| import math |  | ||||||
| from tensorflow.keras.utils import to_categorical | from tensorflow.keras.utils import to_categorical | ||||||
| from PIL import Image, ImageEnhance | from PIL import Image, ImageEnhance | ||||||
| 
 | 
 | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue