diff --git a/pyproject.toml b/pyproject.toml index 8ca6cff..ec3e5f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ classifiers = [ eynollah = "eynollah.cli:main" ocrd-eynollah-segment = "eynollah.ocrd_cli:main" ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main" -eynollah-training = "eynollah.training.cli:main" [project.urls] Homepage = "https://github.com/qurator-spk/eynollah" diff --git a/src/eynollah/training/build_model_load_pretrained_weights_and_save.py b/src/eynollah/training/build_model_load_pretrained_weights_and_save.py index 125611e..ce3d955 100644 --- a/src/eynollah/training/build_model_load_pretrained_weights_and_save.py +++ b/src/eynollah/training/build_model_load_pretrained_weights_and_save.py @@ -1,12 +1,7 @@ -import os -import sys import tensorflow as tf -import warnings from tensorflow.keras.optimizers import * -from sacred import Experiment -from models import * -from utils import * -from metrics import * + +from .models import resnet50_unet def configuration(): diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index d378c3e..3fd93ae 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -1,10 +1,28 @@ import click import json +import os from tqdm import tqdm from pathlib import Path from PIL import Image, ImageDraw, ImageFont +import cv2 +import numpy as np -from .gt_gen_utils import * +from eynollah.training.gt_gen_utils import ( + filter_contours_area_of_image, + find_format_of_given_filename_in_dir, + find_new_features_of_contours, + fit_text_single_line, + get_content_of_dir, + get_images_of_ground_truth, + get_layout_contours_for_visualization, + get_textline_contours_and_ocr_text, + get_textline_contours_for_visualization, + overlay_layout_on_image, + read_xml, + resize_image, + visualize_image_from_contours, + visualize_image_from_contours_layout +) @click.group() def main(): diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index 24837a1..998c8fc 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -13,9 +13,14 @@ import click from tensorflow.python.keras import backend as tensorflow_backend import xml.etree.ElementTree as ET -from .models import * -from .gt_gen_utils import * - +from .gt_gen_utils import ( + filter_contours_area_of_image, + find_new_features_of_contours, + read_xml, + resize_image, + update_list_and_return_first_with_length_bigger_than_one +) +from .models import PatchEncoder, Patches with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 3b99807..527bca6 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -2,20 +2,39 @@ import os import sys import json +from eynollah.training.metrics import soft_dice_loss, weighted_categorical_crossentropy + +from .models import ( + PatchEncoder, + Patches, + machine_based_reading_order_model, + resnet50_classifier, + resnet50_unet, + vit_resnet50_unet, + vit_resnet50_unet_transformer_before_cnn +) +from .utils import ( + data_gen, + generate_arrays_from_folder_reading_order, + generate_data_from_folder_evaluation, + generate_data_from_folder_training, + get_one_hot, + provide_patches, + return_number_of_total_training_data +) + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf from tensorflow.compat.v1.keras.backend import set_session -import warnings -from tensorflow.keras.optimizers import * +from tensorflow.keras.optimizers import SGD, Adam from sacred import Experiment from tensorflow.keras.models import load_model from tqdm import tqdm from sklearn.metrics import f1_score from tensorflow.keras.callbacks import Callback -from .models import * -from .utils import * -from .metrics import * +import numpy as np +import cv2 class SaveWeightsAfterSteps(Callback): def __init__(self, save_interval, save_path, _config): @@ -47,8 +66,8 @@ def configuration(): def get_dirs_or_files(input_data): + image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') if os.path.isdir(input_data): - image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') # Check if training dir exists assert os.path.isdir(image_input), "{} is not a directory".format(image_input) assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) @@ -425,7 +444,7 @@ def run(_config, n_classes, n_epochs, input_height, #f1score_tot = [0] indexer_start = 0 - opt = SGD(learning_rate=0.01, momentum=0.9) + # opt = SGD(learning_rate=0.01, momentum=0.9) opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) model.compile(loss="binary_crossentropy", optimizer = opt_adam,metrics=['accuracy'])