From 95635d5b9ce17b3c417e3869c9586181ede6f384 Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 12:01:54 +0100 Subject: [PATCH 001/123] code to produce models --- train/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 train/.gitkeep diff --git a/train/.gitkeep b/train/.gitkeep new file mode 100644 index 0000000..e69de29 From 4601237427f8b8cc2786a3bf845dbec7dfbd289d Mon Sep 17 00:00:00 2001 From: b-vr103 Date: Thu, 5 Dec 2019 12:10:55 +0100 Subject: [PATCH 002/123] add files needed for training --- train/__init__.py | 0 train/metrics.py | 338 ++++++++++++++++++++++++++++++++++++++++++++++ train/models.py | 317 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 655 insertions(+) create mode 100644 train/__init__.py create mode 100644 train/metrics.py create mode 100644 train/models.py diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/metrics.py b/train/metrics.py new file mode 100644 index 0000000..c63cc22 --- /dev/null +++ b/train/metrics.py @@ -0,0 +1,338 @@ +from keras import backend as K +import tensorflow as tf +import numpy as np + +def focal_loss(gamma=2., alpha=4.): + + gamma = float(gamma) + alpha = float(alpha) + + def focal_loss_fixed(y_true, y_pred): + """Focal loss for multi-classification + FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t) + Notice: y_pred is probability after softmax + gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper + d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x) + Focal Loss for Dense Object Detection + https://arxiv.org/abs/1708.02002 + + Arguments: + y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls] + y_pred {tensor} -- model's output, shape of [batch_size, num_cls] + + Keyword Arguments: + gamma {float} -- (default: {2.0}) + alpha {float} -- (default: {4.0}) + + Returns: + [tensor] -- loss. 
+ """ + epsilon = 1.e-9 + y_true = tf.convert_to_tensor(y_true, tf.float32) + y_pred = tf.convert_to_tensor(y_pred, tf.float32) + + model_out = tf.add(y_pred, epsilon) + ce = tf.multiply(y_true, -tf.log(model_out)) + weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma)) + fl = tf.multiply(alpha, tf.multiply(weight, ce)) + reduced_fl = tf.reduce_max(fl, axis=1) + return tf.reduce_mean(reduced_fl) + return focal_loss_fixed + +def weighted_categorical_crossentropy(weights=None): + """ weighted_categorical_crossentropy + + Args: + * weights: crossentropy weights + Returns: + * weighted categorical crossentropy function + """ + + def loss(y_true, y_pred): + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum(tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + return tf.reduce_mean(per_pixel_loss) + return loss +def image_categorical_cross_entropy(y_true, y_pred, weights=None): + """ + :param y_true: tensor of shape (batch_size, height, width) representing the ground truth. + :param y_pred: tensor of shape (batch_size, height, width) representing the prediction. + :return: The mean cross-entropy on softmaxed tensors. 
+ """ + + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum( + tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + + return tf.reduce_mean(per_pixel_loss) +def class_tversky(y_true, y_pred): + smooth = 1.0#1.00 + + y_true = K.permute_dimensions(y_true, (3,1,2,0)) + y_pred = K.permute_dimensions(y_pred, (3,1,2,0)) + + y_true_pos = K.batch_flatten(y_true) + y_pred_pos = K.batch_flatten(y_pred) + true_pos = K.sum(y_true_pos * y_pred_pos, 1) + false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1) + false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1) + alpha = 0.2#0.5 + beta=0.8 + return (true_pos + smooth)/(true_pos + alpha*false_neg + (beta)*false_pos + smooth) + +def focal_tversky_loss(y_true,y_pred): + pt_1 = class_tversky(y_true, y_pred) + gamma =1.3#4./3.0#1.3#4.0/3.00# 0.75 + return K.sum(K.pow((1-pt_1), gamma)) + +def generalized_dice_coeff2(y_true, y_pred): + n_el = 1 + for dim in y_true.shape: + n_el *= int(dim) + n_cl = y_true.shape[-1] + w = K.zeros(shape=(n_cl,)) + w = (K.sum(y_true, axis=(0,1,2)))/(n_el) + w = 1/(w**2+0.000001) + numerator = y_true*y_pred + numerator = w*K.sum(numerator,(0,1,2)) + numerator = K.sum(numerator) + denominator = y_true+y_pred + denominator = w*K.sum(denominator,(0,1,2)) + denominator = K.sum(denominator) + return 2*numerator/denominator +def generalized_dice_coeff(y_true, y_pred): + axes = tuple(range(1, len(y_pred.shape)-1)) + Ncl = y_pred.shape[-1] + w = K.zeros(shape=(Ncl,)) + w = K.sum(y_true, axis=axes) + w = 1/(w**2+0.000001) + # Compute gen dice coef: + numerator = y_true*y_pred + numerator = w*K.sum(numerator,axes) + numerator = K.sum(numerator) + + denominator = y_true+y_pred + denominator = w*K.sum(denominator,axes) + denominator = K.sum(denominator) + + 
gen_dice_coef = 2*numerator/denominator + + return gen_dice_coef + +def generalized_dice_loss(y_true, y_pred): + return 1 - generalized_dice_coeff2(y_true, y_pred) +def soft_dice_loss(y_true, y_pred, epsilon=1e-6): + ''' + Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. + Assumes the `channels_last` format. + + # Arguments + y_true: b x X x Y( x Z...) x c One hot encoding of ground truth + y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) + epsilon: Used for numerical stability to avoid divide by zero errors + + # References + V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + https://arxiv.org/abs/1606.04797 + More details on Dice loss formulation + https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72) + + Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022 + ''' + + # skip the batch and class axis for calculating Dice score + axes = tuple(range(1, len(y_pred.shape)-1)) + + numerator = 2. * K.sum(y_pred * y_true, axes) + + denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + +def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = True, mean_per_class=False, verbose=False): + """ + Compute mean metrics of two segmentation masks, via Keras. + + IoU(A,B) = |A & B| / (| A U B|) + Dice(A,B) = 2*|A & B| / (|A| + |B|) + + Args: + y_true: true masks, one-hot encoded. + y_pred: predicted masks, either softmax outputs, or one-hot encoded. + metric_name: metric to be computed, either 'iou' or 'dice'. + metric_type: one of 'standard' (default), 'soft', 'naive'. + In the standard version, y_pred is one-hot encoded and the mean + is taken only over classes that are present (in y_true or y_pred). 
+ The 'soft' version of the metrics are computed without one-hot + encoding y_pred. + The 'naive' version return mean metrics where absent classes contribute + to the class mean as 1.0 (instead of being dropped from the mean). + drop_last = True: boolean flag to drop last class (usually reserved + for background class in semantic segmentation) + mean_per_class = False: return mean along batch axis for each class. + verbose = False: print intermediate results such as intersection, union + (as number of pixels). + Returns: + IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True + in which case it returns the per-class metric, averaged over the batch. + + Inputs are B*W*H*N tensors, with + B = batch size, + W = width, + H = height, + N = number of classes + """ + + flag_soft = (metric_type == 'soft') + flag_naive_mean = (metric_type == 'naive') + + # always assume one or more classes + num_classes = K.shape(y_true)[-1] + + if not flag_soft: + # get one-hot encoded masks from y_pred (true masks should already be one-hot) + y_pred = K.one_hot(K.argmax(y_pred), num_classes) + y_true = K.one_hot(K.argmax(y_true), num_classes) + + # if already one-hot, could have skipped above command + # keras uses float32 instead of float64, would give error down (but numpy arrays or keras.to_categorical gives float64) + y_true = K.cast(y_true, 'float32') + y_pred = K.cast(y_pred, 'float32') + + # intersection and union shapes are batch_size * n_classes (values = area in pixels) + axes = (1,2) # W,H axes of each image + intersection = K.sum(K.abs(y_true * y_pred), axis=axes) + mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes) + union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot + + smooth = .001 + iou = (intersection + smooth) / (union + smooth) + dice = 2 * (intersection + smooth)/(mask_sum + smooth) + + metric = {'iou': iou, 'dice': dice}[metric_name] + + # define mask to be 0 when no pixels are present in 
either y_true or y_pred, 1 otherwise + mask = K.cast(K.not_equal(union, 0), 'float32') + + if drop_last: + metric = metric[:,:-1] + mask = mask[:,:-1] + + if verbose: + print('intersection, union') + print(K.eval(intersection), K.eval(union)) + print(K.eval(intersection/union)) + + # return mean metrics: remaining axes are (batch, classes) + if flag_naive_mean: + return K.mean(metric) + + # take mean only over non-absent classes + class_count = K.sum(mask, axis=0) + non_zero = tf.greater(class_count, 0) + non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero) + non_zero_count = tf.boolean_mask(class_count, non_zero) + + if verbose: + print('Counts of inputs with class present, metrics for non-absent classes') + print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count)) + + return K.mean(non_zero_sum / non_zero_count) + +def mean_iou(y_true, y_pred, **kwargs): + """ + Compute mean Intersection over Union of two segmentation masks, via Keras. + + Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs. 
+ """ + return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs) +def Mean_IOU(y_true, y_pred): + nb_classes = K.int_shape(y_pred)[-1] + iou = [] + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + void_labels = K.equal(K.sum(y_true, axis=-1), 0) + for i in range(0, nb_classes): # exclude first label (background) and last label (void) + true_labels = K.equal(true_pixels, i)# & ~void_labels + pred_labels = K.equal(pred_pixels, i)# & ~void_labels + inter = tf.to_int32(true_labels & pred_labels) + union = tf.to_int32(true_labels | pred_labels) + legal_batches = K.sum(tf.to_int32(true_labels), axis=1)>0 + ious = K.sum(inter, axis=1)/K.sum(union, axis=1) + iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects + iou = tf.stack(iou) + legal_labels = ~tf.debugging.is_nan(iou) + iou = tf.gather(iou, indices=tf.where(legal_labels)) + return K.mean(iou) + +def iou_vahid(y_true, y_pred): + nb_classes = tf.shape(y_true)[-1]+tf.to_int32(1) + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + iou = [] + + for i in tf.range(nb_classes): + tp=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fp=K.sum( tf.to_int32( K.not_equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fn=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.not_equal(pred_pixels, i) ) ) + iouh=tp/(tp+fp+fn) + iou.append(iouh) + return K.mean(iou) + + +def IoU_metric(Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + y_predi = np.argmax(y_predi, axis=3) + y_testi = np.argmax(Yi, axis=3) + IoUs = [] + Nclass = int(np.max(Yi)) + 1 + for c in range(Nclass): + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + IoUs.append(IoU) + return K.cast( np.mean(IoUs) ,dtype='float32' ) + + +def IoU_metric_keras(y_true, y_pred): + ## mean 
Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + init = tf.global_variables_initializer() + sess = tf.Session() + sess.run(init) + + return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) + +def jaccard_distance_loss(y_true, y_pred, smooth=100): + """ + Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + + The jaccard distance loss is usefull for unbalanced datasets. This has been + shifted so it converges on 0 and is smoothed to avoid exploding or disapearing + gradient. + + Ref: https://en.wikipedia.org/wiki/Jaccard_index + + @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96 + @author: wassname + """ + intersection = K.sum(K.abs(y_true * y_pred), axis=-1) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + jac = (intersection + smooth) / (sum_ - intersection + smooth) + return (1 - jac) * smooth + + diff --git a/train/models.py b/train/models.py new file mode 100644 index 0000000..7c806b4 --- /dev/null +++ b/train/models.py @@ -0,0 +1,317 @@ +from keras.models import * +from keras.layers import * +from keras import layers +from keras.regularizers import l2 + +resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +IMAGE_ORDERING ='channels_last' +MERGE_AXIS=-1 + + +def one_side_pad( x ): + x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) + if IMAGE_ORDERING == 'channels_first': + x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x) + elif IMAGE_ORDERING == 'channels_last': + x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x) + return x + +def identity_block(input_tensor, kernel_size, filters, stage, block): + """The identity block is the block that has no conv layer at shortcut. 
+ # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , + padding='same', name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + x = layers.add([x, input_tensor]) + x = Activation('relu')(x) + return x + + +def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): + """conv_block is the block that has a conv layer at shortcut + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. 
+ Note that from stage 3, the first conv layer at main path is with strides=(2,2) + And the shortcut should have strides=(2,2) as well + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same', + name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '1')(input_tensor) + shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) + + x = layers.add([x, shortcut]) + x = Activation('relu')(x) + return x + + +def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', 
strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + + if pretraining: + model=Model( img_input , x ).load_weights(resnet50_Weights_path) + + + v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048) + v512_2048 = Activation('relu')(v512_2048) + + + + v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 ) + v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024) + v512_1024 = Activation('relu')(v512_1024) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048) + o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), 
data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + + model = Model( img_input , o ) + return model + +def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = 
ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model( img_input , x ).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048) + o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', 
data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + model = Model( img_input , o ) + + + + + return model From 226330535d0d01c67e4c18c7957e3d69b8f5f672 Mon Sep 17 00:00:00 2001 From: b-vr103 Date: Thu, 5 Dec 2019 14:05:07 +0100 Subject: 
[PATCH 003/123] add files needed for training --- train/README | 23 +++ train/config_params.json | 24 +++ train/train.py | 192 ++++++++++++++++++++++ train/utils.py | 336 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 575 insertions(+) create mode 100644 train/README create mode 100644 train/config_params.json create mode 100644 train/train.py create mode 100644 train/utils.py diff --git a/train/README b/train/README new file mode 100644 index 0000000..7d8d790 --- /dev/null +++ b/train/README @@ -0,0 +1,23 @@ +how to train: + just run: python train.py with config_params.json + + +format of ground truth: + + Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and labels should be 0 and 1 for each class and pixel. + In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels by pixels from 0 , 1 ,2 .., n_classes-1. + The labels format should be png. + + If you have an image label for binary case it should look like this: + + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + +traing , evaluation and output: + train and evaluation folder should have subfolder of images and labels. + And output folder should be free folder which the output model will be written there. + +patches: + + if you want to train your model with patches, the height and width of patches should be defined and also number of batchs (how many patches should be seen by model by each iteration). + In the case that model should see the image once, like page extraction, the patches should be set to false. 
+ diff --git a/train/config_params.json b/train/config_params.json new file mode 100644 index 0000000..52db6db --- /dev/null +++ b/train/config_params.json @@ -0,0 +1,24 @@ +{ + "n_classes" : 2, + "n_epochs" : 2, + "input_height" : 448, + "input_width" : 896, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "elastic_aug" : false, + "blur_aug" : false, + "scaling" : false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "rotation": false, + "weighted_loss": true, + "dir_train": "/home/vahid/textline_gt_images/train_light", + "dir_eval": "/home/vahid/textline_gt_images/eval", + "dir_output": "/home/vahid/textline_gt_images/output" +} diff --git a/train/train.py b/train/train.py new file mode 100644 index 0000000..07c7418 --- /dev/null +++ b/train/train.py @@ -0,0 +1,192 @@ +import os +import sys +import tensorflow as tf +from keras.backend.tensorflow_backend import set_session +import keras , warnings +from keras.optimizers import * +from sacred import Experiment +from models import * +from utils import * +from metrics import * + + +def configuration(): + keras.backend.clear_session() + tf.reset_default_graph() + warnings.filterwarnings('ignore') + + os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' + config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) + + + config.gpu_options.allow_growth = True + config.gpu_options.per_process_gpu_memory_fraction=0.95#0.95 + config.gpu_options.visible_device_list="0" + set_session(tf.Session(config=config)) + +def get_dirs_or_files(input_data): + if os.path.isdir(input_data): + image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') + # Check if training dir exists + assert os.path.isdir(image_input), "{} is not a directory".format(image_input) + assert os.path.isdir(labels_input), "{} is not a 
directory".format(labels_input) + return image_input, labels_input + +ex = Experiment() + +@ex.config +def config_params(): + n_classes=None # Number of classes. If your case study is binary case the set it to 2 and otherwise give your number of cases. + n_epochs=1 + input_height=224*1 + input_width=224*1 + weight_decay=1e-6 # Weight decay of l2 regularization of model layers. + n_batch=1 # Number of batches at each iteration. + learning_rate=1e-4 + patches=False # Make patches of image in order to use all information of image. In the case of page + # extraction this should be set to false since model should see all image. + augmentation=False + flip_aug=False # Flip image (augmentation). + elastic_aug=False # Elastic transformation (augmentation). + blur_aug=False # Blur patches of image (augmentation). + scaling=False # Scaling of patches (augmentation) will be imposed if this set to true. + binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied. + dir_train=None # Directory of training dataset (sub-folders should be named images and labels). + dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels). + dir_output=None # Directory of output where the model should be saved. + pretraining=False # Set true to load pretrained weights of resnet50 encoder. + weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function. + scaling_bluring=False + rotation: False + scaling_binarization=False + blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. + scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation. + flip_index=[0,1] # Flip image. Used for augmentation. 
+ + +@ex.automain +def run(n_classes,n_epochs,input_height, + input_width,weight_decay,weighted_loss, + n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization, + blur_k,scales,dir_train, + scaling_bluring,scaling_binarization,rotation, + flip_index,dir_eval ,dir_output,pretraining,learning_rate): + + dir_img,dir_seg=get_dirs_or_files(dir_train) + dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) + + # make first a directory in output for both training and evaluations in order to flow data from these directories. + dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') + + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') + + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') + + if os.path.isdir(dir_train_flowing): + os.system('rm -rf '+dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) + + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf '+dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) + + + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) + + os.mkdir(dir_flow_eval_imgs) + os.mkdir(dir_flow_eval_labels) + + + + #set the gpu configuration + configuration() + + + #writing patches into a sub-folder in order to be flowed from directory. 
+ provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=augmentation,patches=patches) + + provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, + dir_flow_eval_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=False,patches=patches) + + if weighted_loss: + weights=np.zeros(n_classes) + for obj in os.listdir(dir_seg): + label_obj=cv2.imread(dir_seg+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + + + weights=1.00/weights + + weights=weights/float(np.sum(weights)) + weights=weights/float(np.min(weights)) + weights=weights/float(np.sum(weights)) + + + + + #get our model. + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + + #if you want to see the model structure just uncomment model summary. 
+ #model.summary() + + + if not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + + mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', + save_weights_only=True, period=1) + + + #generating train and evaluation data + train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, + input_height=input_height, input_width=input_width,n_classes=n_classes ) + val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, + input_height=input_height, input_width=input_width,n_classes=n_classes ) + + + model.fit_generator( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch), + validation_data=val_gen, + validation_steps=1, + epochs=n_epochs) + + + + os.system('rm -rf '+dir_train_flowing) + os.system('rm -rf '+dir_eval_flowing) + + model.save(dir_output+'/'+'model'+'.h5') + + + + + + + + + + diff --git a/train/utils.py b/train/utils.py new file mode 100644 index 0000000..afdc9e5 --- /dev/null +++ b/train/utils.py @@ -0,0 +1,336 @@ +import os +import cv2 +import numpy as np +import seaborn as sns +from scipy.ndimage.interpolation import map_coordinates +from scipy.ndimage.filters import gaussian_filter +import random +from tqdm import tqdm + + + + +def bluring(img_in,kind): + if kind=='guass': + img_blur = cv2.GaussianBlur(img_in,(5,5),0) + elif kind=="median": + img_blur = cv2.medianBlur(img_in,5) + elif kind=='blur': + img_blur=cv2.blur(img_in,(5,5)) + return img_blur + +def color_images(seg, n_classes): + ann_u=range(n_classes) + if len(np.shape(seg))==3: + seg=seg[:,:,0] + + seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(float) + colors=sns.color_palette("hls", n_classes) + + for c in ann_u: + c=int(c) + segl=(seg==c) + seg_img[:,:,0]+=segl*(colors[c][0]) + 
seg_img[:,:,1]+=segl*(colors[c][1]) + seg_img[:,:,2]+=segl*(colors[c][2]) + return seg_img + + +def resize_image(seg_in,input_height,input_width): + return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST) +def get_one_hot(seg,input_height,input_width,n_classes): + seg=seg[:,:,0] + seg_f=np.zeros((input_height, input_width,n_classes)) + for j in range(n_classes): + seg_f[:,:,j]=(seg==j).astype(int) + return seg_f + + +def IoU(Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + + IoUs = [] + classes_true=np.unique(Yi) + for c in classes_true: + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU)) + IoUs.append(IoU) + mIoU = np.mean(IoUs) + print("_________________") + print("Mean IoU: {:4.3f}".format(mIoU)) + return mIoU +def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes): + c = 0 + n = os.listdir(img_folder) #List of training images + random.shuffle(n) + while True: + img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') + mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') + + for i in range(c, c+batch_size): #initially from 0 to 16, c = 0. + #print(img_folder+'/'+n[i]) + filename=n[i].split('.')[0] + train_img = cv2.imread(img_folder+'/'+n[i])/255. + train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize + + img[i-c] = train_img #add to array - img[0], img[1], and so on. 
+ train_mask = cv2.imread(mask_folder+'/'+filename+'.png') + #print(mask_folder+'/'+filename+'.png') + #print(train_mask.shape) + train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) + #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] + + mask[i-c] = train_mask + + c+=batch_size + if(c+batch_size>=len(os.listdir(img_folder))): + c=0 + random.shuffle(n) + yield img, mask + +def otsu_copy(img): + img_r=np.zeros(img.shape) + img1=img[:,:,0] + img2=img[:,:,1] + img3=img[:,:,2] + _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + img_r[:,:,0]=threshold1 + img_r[:,:,1]=threshold1 + img_r[:,:,2]=threshold1 + return img_r + +def rotation_90(img): + img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) + img_rot[:,:,0]=img[:,:,0].T + img_rot[:,:,1]=img[:,:,1].T + img_rot[:,:,2]=img[:,:,2].T + return img_rot + +def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer): + + + img_h=img.shape[0] + img_w=img.shape[1] + + nxf=img_w/float(width) + nyf=img_h/float(height) + + if nxf>int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width + index_x_u=(i+1)*width + + index_y_d=j*height + index_y_u=(j+1)*height + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + return indexer + + + +def 
get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler): + + + img_h=img.shape[0] + img_w=img.shape[1] + + height_scale=int(height*scaler) + width_scale=int(width*scaler) + + + nxf=img_w/float(width_scale) + nyf=img_h/float(height_scale) + + if nxf>int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width_scale + index_x_u=(i+1)*width_scale + + index_y_d=j*height_scale + index_y_u=(j+1)*height_scale + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width_scale + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height_scale + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + img_patch=resize_image(img_patch,height,width) + label_patch=resize_image(label_patch,height,width) + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + + return indexer + + + +def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=False,patches=False): + + imgs_cv_train=np.array(os.listdir(dir_img)) + segs_cv_train=np.array(os.listdir(dir_seg)) + + indexer=0 + for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): + img_name=im.split('.')[0] + + if not patches: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) + indexer+=1 + + if augmentation: + if rotation: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + rotation_90( 
resize_image(cv2.imread(dir_img+'/'+im), + input_height,input_width) ) ) + + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', + rotation_90 ( resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width) ) ) + indexer+=1 + + if flip_aug: + for f_i in flip_index: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , + resize_image(cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png'),f_i),input_height,input_width) ) + indexer+=1 + + if blur_aug: + for blur_i in blur_k: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + (resize_image(bluring(cv2.imread(dir_img+'/'+im),blur_i),input_height,input_width) ) ) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , + resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width) ) + indexer+=1 + + + if binarization: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + resize_image(otsu_copy( cv2.imread(dir_img+'/'+im)),input_height,input_width )) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', + resize_image( cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width )) + indexer+=1 + + + + + + + if patches: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + if augmentation: + + if rotation: + + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + rotation_90( cv2.imread(dir_img+'/'+im) ), + rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ), + input_height,input_width,indexer=indexer) + if flip_aug: + for f_i in flip_index: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + cv2.flip( cv2.imread(dir_img+'/'+im) , f_i), + cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i), + 
input_height,input_width,indexer=indexer) + if blur_aug: + for blur_i in blur_k: + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + bluring( cv2.imread(dir_img+'/'+im) , blur_i), + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + + if scaling: + for sc_ind in scales: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im) , + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer,scaler=sc_ind) + if binarization: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)), + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + + + if scaling_bluring: + for sc_ind in scales: + for blur_i in blur_k: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + bluring( cv2.imread(dir_img+'/'+im) , blur_i) , + cv2.imread(dir_seg+'/'+img_name+'.png') , + input_height,input_width,indexer=indexer,scaler=sc_ind) + + if scaling_binarization: + for sc_ind in scales: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)) , + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer,scaler=sc_ind) + + + + + + From 1882dd8f53b665993c806ff5587562772f65c8a7 Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 14:05:55 +0100 Subject: [PATCH 004/123] Update config_params.json --- train/config_params.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index 52db6db..5066444 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -18,7 +18,7 @@ "scaling_binarization" : false, "rotation": false, "weighted_loss": true, - "dir_train": "/home/vahid/textline_gt_images/train_light", - "dir_eval": "/home/vahid/textline_gt_images/eval", - "dir_output": 
"/home/vahid/textline_gt_images/output" + "dir_train": "../train", + "dir_eval": "../eval", + "dir_output": "../output" } From e8afb370bafa617250ef3f15fe35a721e0a1ccbd Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 14:08:08 +0100 Subject: [PATCH 005/123] Update README --- train/README | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/train/README b/train/README index 7d8d790..8d478bd 100644 --- a/train/README +++ b/train/README @@ -4,17 +4,20 @@ how to train: format of ground truth: - Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and labels should be 0 and 1 for each class and pixel. - In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels by pixels from 0 , 1 ,2 .., n_classes-1. + Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and + labels should be 0 and 1 for each class and pixel. + In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels + by pixels set from 0 , 1 ,2 .., n_classes-1. The labels format should be png. If you have an image label for binary case it should look like this: - Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] + this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. -traing , evaluation and output: +training , evaluation and output: train and evaluation folder should have subfolder of images and labels. - And output folder should be free folder which the output model will be written there. 
+ And output folder should be empty folder which the output model will be written there. patches: From 99a02a1bf55a8022110ca78d0363c2eae610cecf Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 14:11:37 +0100 Subject: [PATCH 006/123] Update README --- train/README | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train/README b/train/README index 8d478bd..54ea408 100644 --- a/train/README +++ b/train/README @@ -21,6 +21,7 @@ training , evaluation and output: patches: - if you want to train your model with patches, the height and width of patches should be defined and also number of batchs (how many patches should be seen by model by each iteration). + if you want to train your model with patches, the height and width of patches should be defined and also number of + batchs (how many patches should be seen by model by each iteration). In the case that model should see the image once, like page extraction, the patches should be set to false. From 7eb3dd26addb0131cf39c6bdbf0dcd88ed61d8d5 Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 16:11:31 +0100 Subject: [PATCH 007/123] Update README --- train/README | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train/README b/train/README index 54ea408..e103b0b 100644 --- a/train/README +++ b/train/README @@ -1,8 +1,8 @@ -how to train: +# Train just run: python train.py with config_params.json -format of ground truth: +# Ground truth format Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and labels should be 0 and 1 for each class and pixel. @@ -15,11 +15,11 @@ format of ground truth: Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. 
-training , evaluation and output: +# Training , evaluation and output train and evaluation folder should have subfolder of images and labels. And output folder should be empty folder which the output model will be written there. -patches: +# Patches if you want to train your model with patches, the height and width of patches should be defined and also number of batchs (how many patches should be seen by model by each iteration). From cf18aa7fbb64900979b816b6b03ff20c5378b3a9 Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 16:13:37 +0100 Subject: [PATCH 008/123] Update README --- train/README | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/train/README b/train/README index e103b0b..5237d53 100644 --- a/train/README +++ b/train/README @@ -1,27 +1,2 @@ -# Train - just run: python train.py with config_params.json - - -# Ground truth format - - Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and - labels should be 0 and 1 for each class and pixel. - In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels - by pixels set from 0 , 1 ,2 .., n_classes-1. - The labels format should be png. - - If you have an image label for binary case it should look like this: - - Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] - this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. - -# Training , evaluation and output - train and evaluation folder should have subfolder of images and labels. - And output folder should be empty folder which the output model will be written there. - -# Patches - - if you want to train your model with patches, the height and width of patches should be defined and also number of - batchs (how many patches should be seen by model by each iteration). 
- In the case that model should see the image once, like page extraction, the patches should be set to false. + From ac542665815bea97752440bcf874a21ec939c047 Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 16:13:40 +0100 Subject: [PATCH 009/123] Delete README --- train/README | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 train/README diff --git a/train/README b/train/README deleted file mode 100644 index 5237d53..0000000 --- a/train/README +++ /dev/null @@ -1,2 +0,0 @@ - - From 350378af168d68f4709c1b98bc8e867e9b46ccfd Mon Sep 17 00:00:00 2001 From: "Rezanezhad, Vahid" Date: Thu, 5 Dec 2019 16:14:00 +0100 Subject: [PATCH 010/123] Add new file --- train/README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 train/README.md diff --git a/train/README.md b/train/README.md new file mode 100644 index 0000000..c4dc27e --- /dev/null +++ b/train/README.md @@ -0,0 +1,26 @@ +# Train + just run: python train.py with config_params.json + + +# Ground truth format + + Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and + labels should be 0 and 1 for each class and pixel. + In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels + by pixels set from 0 , 1 ,2 .., n_classes-1. + The labels format should be png. + + If you have an image label for binary case it should look like this: + + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] + this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + +# Training , evaluation and output + train and evaluation folder should have subfolder of images and labels. + And output folder should be empty folder which the output model will be written there. 
+ +# Patches + + if you want to train your model with patches, the height and width of patches should be defined and also number of + batchs (how many patches should be seen by model by each iteration). + In the case that model should see the image once, like page extraction, the patches should be set to false. \ No newline at end of file From 979b824aa8fe84619e9863372b45647ed8306327 Mon Sep 17 00:00:00 2001 From: "Gerber, Mike" Date: Mon, 9 Dec 2019 15:33:53 +0100 Subject: [PATCH 011/123] =?UTF-8?q?=F0=9F=93=9D=20howto:=20Be=20more=20ver?= =?UTF-8?q?bose=20with=20the=20subtree=20pull?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train/.gitkeep | 0 train/README.md | 26 +++ train/__init__.py | 0 train/config_params.json | 24 +++ train/metrics.py | 338 +++++++++++++++++++++++++++++++++++++++ train/models.py | 317 ++++++++++++++++++++++++++++++++++++ train/train.py | 192 ++++++++++++++++++++++ train/utils.py | 336 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 1233 insertions(+) create mode 100644 train/.gitkeep create mode 100644 train/README.md create mode 100644 train/__init__.py create mode 100644 train/config_params.json create mode 100644 train/metrics.py create mode 100644 train/models.py create mode 100644 train/train.py create mode 100644 train/utils.py diff --git a/train/.gitkeep b/train/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/train/README.md b/train/README.md new file mode 100644 index 0000000..c4dc27e --- /dev/null +++ b/train/README.md @@ -0,0 +1,26 @@ +# Train + just run: python train.py with config_params.json + + +# Ground truth format + + Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and + labels should be 0 and 1 for each class and pixel. + In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels + by pixels set from 0 , 1 ,2 .., n_classes-1. 
+ The labels format should be png. + + If you have an image label for binary case it should look like this: + + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] + this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + +# Training , evaluation and output + train and evaluation folder should have subfolder of images and labels. + And output folder should be empty folder which the output model will be written there. + +# Patches + + if you want to train your model with patches, the height and width of patches should be defined and also number of + batchs (how many patches should be seen by model by each iteration). + In the case that model should see the image once, like page extraction, the patches should be set to false. \ No newline at end of file diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/config_params.json b/train/config_params.json new file mode 100644 index 0000000..5066444 --- /dev/null +++ b/train/config_params.json @@ -0,0 +1,24 @@ +{ + "n_classes" : 2, + "n_epochs" : 2, + "input_height" : 448, + "input_width" : 896, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "elastic_aug" : false, + "blur_aug" : false, + "scaling" : false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "rotation": false, + "weighted_loss": true, + "dir_train": "../train", + "dir_eval": "../eval", + "dir_output": "../output" +} diff --git a/train/metrics.py b/train/metrics.py new file mode 100644 index 0000000..c63cc22 --- /dev/null +++ b/train/metrics.py @@ -0,0 +1,338 @@ +from keras import backend as K +import tensorflow as tf +import numpy as np + +def focal_loss(gamma=2., alpha=4.): + + gamma = float(gamma) + alpha = float(alpha) + + def 
focal_loss_fixed(y_true, y_pred): + """Focal loss for multi-classification + FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t) + Notice: y_pred is probability after softmax + gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper + d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x) + Focal Loss for Dense Object Detection + https://arxiv.org/abs/1708.02002 + + Arguments: + y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls] + y_pred {tensor} -- model's output, shape of [batch_size, num_cls] + + Keyword Arguments: + gamma {float} -- (default: {2.0}) + alpha {float} -- (default: {4.0}) + + Returns: + [tensor] -- loss. + """ + epsilon = 1.e-9 + y_true = tf.convert_to_tensor(y_true, tf.float32) + y_pred = tf.convert_to_tensor(y_pred, tf.float32) + + model_out = tf.add(y_pred, epsilon) + ce = tf.multiply(y_true, -tf.log(model_out)) + weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma)) + fl = tf.multiply(alpha, tf.multiply(weight, ce)) + reduced_fl = tf.reduce_max(fl, axis=1) + return tf.reduce_mean(reduced_fl) + return focal_loss_fixed + +def weighted_categorical_crossentropy(weights=None): + """ weighted_categorical_crossentropy + + Args: + * weights: crossentropy weights + Returns: + * weighted categorical crossentropy function + """ + + def loss(y_true, y_pred): + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum(tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + return tf.reduce_mean(per_pixel_loss) + return loss +def image_categorical_cross_entropy(y_true, y_pred, weights=None): + """ + :param y_true: tensor of shape (batch_size, height, width) representing the ground truth. + :param y_pred: tensor of shape (batch_size, height, width) representing the prediction. 
+ :return: The mean cross-entropy on softmaxed tensors. + """ + + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum( + tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + + return tf.reduce_mean(per_pixel_loss) +def class_tversky(y_true, y_pred): + smooth = 1.0#1.00 + + y_true = K.permute_dimensions(y_true, (3,1,2,0)) + y_pred = K.permute_dimensions(y_pred, (3,1,2,0)) + + y_true_pos = K.batch_flatten(y_true) + y_pred_pos = K.batch_flatten(y_pred) + true_pos = K.sum(y_true_pos * y_pred_pos, 1) + false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1) + false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1) + alpha = 0.2#0.5 + beta=0.8 + return (true_pos + smooth)/(true_pos + alpha*false_neg + (beta)*false_pos + smooth) + +def focal_tversky_loss(y_true,y_pred): + pt_1 = class_tversky(y_true, y_pred) + gamma =1.3#4./3.0#1.3#4.0/3.00# 0.75 + return K.sum(K.pow((1-pt_1), gamma)) + +def generalized_dice_coeff2(y_true, y_pred): + n_el = 1 + for dim in y_true.shape: + n_el *= int(dim) + n_cl = y_true.shape[-1] + w = K.zeros(shape=(n_cl,)) + w = (K.sum(y_true, axis=(0,1,2)))/(n_el) + w = 1/(w**2+0.000001) + numerator = y_true*y_pred + numerator = w*K.sum(numerator,(0,1,2)) + numerator = K.sum(numerator) + denominator = y_true+y_pred + denominator = w*K.sum(denominator,(0,1,2)) + denominator = K.sum(denominator) + return 2*numerator/denominator +def generalized_dice_coeff(y_true, y_pred): + axes = tuple(range(1, len(y_pred.shape)-1)) + Ncl = y_pred.shape[-1] + w = K.zeros(shape=(Ncl,)) + w = K.sum(y_true, axis=axes) + w = 1/(w**2+0.000001) + # Compute gen dice coef: + numerator = y_true*y_pred + numerator = w*K.sum(numerator,axes) + numerator = K.sum(numerator) + + denominator = y_true+y_pred + denominator = 
w*K.sum(denominator,axes) + denominator = K.sum(denominator) + + gen_dice_coef = 2*numerator/denominator + + return gen_dice_coef + +def generalized_dice_loss(y_true, y_pred): + return 1 - generalized_dice_coeff2(y_true, y_pred) +def soft_dice_loss(y_true, y_pred, epsilon=1e-6): + ''' + Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. + Assumes the `channels_last` format. + + # Arguments + y_true: b x X x Y( x Z...) x c One hot encoding of ground truth + y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) + epsilon: Used for numerical stability to avoid divide by zero errors + + # References + V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + https://arxiv.org/abs/1606.04797 + More details on Dice loss formulation + https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72) + + Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022 + ''' + + # skip the batch and class axis for calculating Dice score + axes = tuple(range(1, len(y_pred.shape)-1)) + + numerator = 2. * K.sum(y_pred * y_true, axes) + + denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + +def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = True, mean_per_class=False, verbose=False): + """ + Compute mean metrics of two segmentation masks, via Keras. + + IoU(A,B) = |A & B| / (| A U B|) + Dice(A,B) = 2*|A & B| / (|A| + |B|) + + Args: + y_true: true masks, one-hot encoded. + y_pred: predicted masks, either softmax outputs, or one-hot encoded. + metric_name: metric to be computed, either 'iou' or 'dice'. + metric_type: one of 'standard' (default), 'soft', 'naive'. + In the standard version, y_pred is one-hot encoded and the mean + is taken only over classes that are present (in y_true or y_pred). 
+ The 'soft' version of the metrics are computed without one-hot + encoding y_pred. + The 'naive' version return mean metrics where absent classes contribute + to the class mean as 1.0 (instead of being dropped from the mean). + drop_last = True: boolean flag to drop last class (usually reserved + for background class in semantic segmentation) + mean_per_class = False: return mean along batch axis for each class. + verbose = False: print intermediate results such as intersection, union + (as number of pixels). + Returns: + IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True + in which case it returns the per-class metric, averaged over the batch. + + Inputs are B*W*H*N tensors, with + B = batch size, + W = width, + H = height, + N = number of classes + """ + + flag_soft = (metric_type == 'soft') + flag_naive_mean = (metric_type == 'naive') + + # always assume one or more classes + num_classes = K.shape(y_true)[-1] + + if not flag_soft: + # get one-hot encoded masks from y_pred (true masks should already be one-hot) + y_pred = K.one_hot(K.argmax(y_pred), num_classes) + y_true = K.one_hot(K.argmax(y_true), num_classes) + + # if already one-hot, could have skipped above command + # keras uses float32 instead of float64, would give error down (but numpy arrays or keras.to_categorical gives float64) + y_true = K.cast(y_true, 'float32') + y_pred = K.cast(y_pred, 'float32') + + # intersection and union shapes are batch_size * n_classes (values = area in pixels) + axes = (1,2) # W,H axes of each image + intersection = K.sum(K.abs(y_true * y_pred), axis=axes) + mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes) + union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot + + smooth = .001 + iou = (intersection + smooth) / (union + smooth) + dice = 2 * (intersection + smooth)/(mask_sum + smooth) + + metric = {'iou': iou, 'dice': dice}[metric_name] + + # define mask to be 0 when no pixels are present in 
either y_true or y_pred, 1 otherwise + mask = K.cast(K.not_equal(union, 0), 'float32') + + if drop_last: + metric = metric[:,:-1] + mask = mask[:,:-1] + + if verbose: + print('intersection, union') + print(K.eval(intersection), K.eval(union)) + print(K.eval(intersection/union)) + + # return mean metrics: remaining axes are (batch, classes) + if flag_naive_mean: + return K.mean(metric) + + # take mean only over non-absent classes + class_count = K.sum(mask, axis=0) + non_zero = tf.greater(class_count, 0) + non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero) + non_zero_count = tf.boolean_mask(class_count, non_zero) + + if verbose: + print('Counts of inputs with class present, metrics for non-absent classes') + print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count)) + + return K.mean(non_zero_sum / non_zero_count) + +def mean_iou(y_true, y_pred, **kwargs): + """ + Compute mean Intersection over Union of two segmentation masks, via Keras. + + Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs. 
+ """ + return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs) +def Mean_IOU(y_true, y_pred): + nb_classes = K.int_shape(y_pred)[-1] + iou = [] + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + void_labels = K.equal(K.sum(y_true, axis=-1), 0) + for i in range(0, nb_classes): # exclude first label (background) and last label (void) + true_labels = K.equal(true_pixels, i)# & ~void_labels + pred_labels = K.equal(pred_pixels, i)# & ~void_labels + inter = tf.to_int32(true_labels & pred_labels) + union = tf.to_int32(true_labels | pred_labels) + legal_batches = K.sum(tf.to_int32(true_labels), axis=1)>0 + ious = K.sum(inter, axis=1)/K.sum(union, axis=1) + iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects + iou = tf.stack(iou) + legal_labels = ~tf.debugging.is_nan(iou) + iou = tf.gather(iou, indices=tf.where(legal_labels)) + return K.mean(iou) + +def iou_vahid(y_true, y_pred): + nb_classes = tf.shape(y_true)[-1]+tf.to_int32(1) + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + iou = [] + + for i in tf.range(nb_classes): + tp=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fp=K.sum( tf.to_int32( K.not_equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fn=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.not_equal(pred_pixels, i) ) ) + iouh=tp/(tp+fp+fn) + iou.append(iouh) + return K.mean(iou) + + +def IoU_metric(Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + y_predi = np.argmax(y_predi, axis=3) + y_testi = np.argmax(Yi, axis=3) + IoUs = [] + Nclass = int(np.max(Yi)) + 1 + for c in range(Nclass): + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + IoUs.append(IoU) + return K.cast( np.mean(IoUs) ,dtype='float32' ) + + +def IoU_metric_keras(y_true, y_pred): + ## mean 
Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + init = tf.global_variables_initializer() + sess = tf.Session() + sess.run(init) + + return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) + +def jaccard_distance_loss(y_true, y_pred, smooth=100): + """ + Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + + The jaccard distance loss is usefull for unbalanced datasets. This has been + shifted so it converges on 0 and is smoothed to avoid exploding or disapearing + gradient. + + Ref: https://en.wikipedia.org/wiki/Jaccard_index + + @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96 + @author: wassname + """ + intersection = K.sum(K.abs(y_true * y_pred), axis=-1) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + jac = (intersection + smooth) / (sum_ - intersection + smooth) + return (1 - jac) * smooth + + diff --git a/train/models.py b/train/models.py new file mode 100644 index 0000000..7c806b4 --- /dev/null +++ b/train/models.py @@ -0,0 +1,317 @@ +from keras.models import * +from keras.layers import * +from keras import layers +from keras.regularizers import l2 + +resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +IMAGE_ORDERING ='channels_last' +MERGE_AXIS=-1 + + +def one_side_pad( x ): + x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) + if IMAGE_ORDERING == 'channels_first': + x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x) + elif IMAGE_ORDERING == 'channels_last': + x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x) + return x + +def identity_block(input_tensor, kernel_size, filters, stage, block): + """The identity block is the block that has no conv layer at shortcut. 
+ # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , + padding='same', name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + x = layers.add([x, input_tensor]) + x = Activation('relu')(x) + return x + + +def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): + """conv_block is the block that has a conv layer at shortcut + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. 
+ Note that from stage 3, the first conv layer at main path is with strides=(2,2) + And the shortcut should have strides=(2,2) as well + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same', + name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '1')(input_tensor) + shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) + + x = layers.add([x, shortcut]) + x = Activation('relu')(x) + return x + + +def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', 
strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + + if pretraining: + model=Model( img_input , x ).load_weights(resnet50_Weights_path) + + + v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048) + v512_2048 = Activation('relu')(v512_2048) + + + + v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 ) + v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024) + v512_1024 = Activation('relu')(v512_1024) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048) + o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), 
data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + + model = Model( img_input , o ) + return model + +def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = 
ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model( img_input , x ).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048) + o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', 
data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + model = Model( img_input , o ) + + + + + return model diff --git a/train/train.py b/train/train.py new file mode 100644 index 0000000..07c7418 --- /dev/null +++ b/train/train.py @@ -0,0 
+1,192 @@ +import os +import sys +import tensorflow as tf +from keras.backend.tensorflow_backend import set_session +import keras , warnings +from keras.optimizers import * +from sacred import Experiment +from models import * +from utils import * +from metrics import * + + +def configuration(): + keras.backend.clear_session() + tf.reset_default_graph() + warnings.filterwarnings('ignore') + + os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' + config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) + + + config.gpu_options.allow_growth = True + config.gpu_options.per_process_gpu_memory_fraction=0.95#0.95 + config.gpu_options.visible_device_list="0" + set_session(tf.Session(config=config)) + +def get_dirs_or_files(input_data): + if os.path.isdir(input_data): + image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') + # Check if training dir exists + assert os.path.isdir(image_input), "{} is not a directory".format(image_input) + assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) + return image_input, labels_input + +ex = Experiment() + +@ex.config +def config_params(): + n_classes=None # Number of classes. If your case study is binary case the set it to 2 and otherwise give your number of cases. + n_epochs=1 + input_height=224*1 + input_width=224*1 + weight_decay=1e-6 # Weight decay of l2 regularization of model layers. + n_batch=1 # Number of batches at each iteration. + learning_rate=1e-4 + patches=False # Make patches of image in order to use all information of image. In the case of page + # extraction this should be set to false since model should see all image. + augmentation=False + flip_aug=False # Flip image (augmentation). + elastic_aug=False # Elastic transformation (augmentation). + blur_aug=False # Blur patches of image (augmentation). + scaling=False # Scaling of patches (augmentation) will be imposed if this set to true. + binarization=False # Otsu thresholding. 
Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied. + dir_train=None # Directory of training dataset (sub-folders should be named images and labels). + dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels). + dir_output=None # Directory of output where the model should be saved. + pretraining=False # Set true to load pretrained weights of resnet50 encoder. + weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function. + scaling_bluring=False + rotation: False + scaling_binarization=False + blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. + scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation. + flip_index=[0,1] # Flip image. Used for augmentation. + + +@ex.automain +def run(n_classes,n_epochs,input_height, + input_width,weight_decay,weighted_loss, + n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization, + blur_k,scales,dir_train, + scaling_bluring,scaling_binarization,rotation, + flip_index,dir_eval ,dir_output,pretraining,learning_rate): + + dir_img,dir_seg=get_dirs_or_files(dir_train) + dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) + + # make first a directory in output for both training and evaluations in order to flow data from these directories. 
+ dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') + + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') + + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') + + if os.path.isdir(dir_train_flowing): + os.system('rm -rf '+dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) + + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf '+dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) + + + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) + + os.mkdir(dir_flow_eval_imgs) + os.mkdir(dir_flow_eval_labels) + + + + #set the gpu configuration + configuration() + + + #writing patches into a sub-folder in order to be flowed from directory. + provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=augmentation,patches=patches) + + provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, + dir_flow_eval_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=False,patches=patches) + + if weighted_loss: + weights=np.zeros(n_classes) + for obj in os.listdir(dir_seg): + label_obj=cv2.imread(dir_seg+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + + + weights=1.00/weights + + weights=weights/float(np.sum(weights)) + weights=weights/float(np.min(weights)) + weights=weights/float(np.sum(weights)) + + + + + #get our model. 
+ model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + + #if you want to see the model structure just uncomment model summary. + #model.summary() + + + if not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + + mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', + save_weights_only=True, period=1) + + + #generating train and evaluation data + train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, + input_height=input_height, input_width=input_width,n_classes=n_classes ) + val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, + input_height=input_height, input_width=input_width,n_classes=n_classes ) + + + model.fit_generator( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch), + validation_data=val_gen, + validation_steps=1, + epochs=n_epochs) + + + + os.system('rm -rf '+dir_train_flowing) + os.system('rm -rf '+dir_eval_flowing) + + model.save(dir_output+'/'+'model'+'.h5') + + + + + + + + + + diff --git a/train/utils.py b/train/utils.py new file mode 100644 index 0000000..afdc9e5 --- /dev/null +++ b/train/utils.py @@ -0,0 +1,336 @@ +import os +import cv2 +import numpy as np +import seaborn as sns +from scipy.ndimage.interpolation import map_coordinates +from scipy.ndimage.filters import gaussian_filter +import random +from tqdm import tqdm + + + + +def bluring(img_in,kind): + if kind=='guass': + img_blur = cv2.GaussianBlur(img_in,(5,5),0) + elif kind=="median": + img_blur = cv2.medianBlur(img_in,5) + elif kind=='blur': + img_blur=cv2.blur(img_in,(5,5)) + return img_blur + +def color_images(seg, n_classes): + ann_u=range(n_classes) + if len(np.shape(seg))==3: + seg=seg[:,:,0] + + 
seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(float) + colors=sns.color_palette("hls", n_classes) + + for c in ann_u: + c=int(c) + segl=(seg==c) + seg_img[:,:,0]+=segl*(colors[c][0]) + seg_img[:,:,1]+=segl*(colors[c][1]) + seg_img[:,:,2]+=segl*(colors[c][2]) + return seg_img + + +def resize_image(seg_in,input_height,input_width): + return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST) +def get_one_hot(seg,input_height,input_width,n_classes): + seg=seg[:,:,0] + seg_f=np.zeros((input_height, input_width,n_classes)) + for j in range(n_classes): + seg_f[:,:,j]=(seg==j).astype(int) + return seg_f + + +def IoU(Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + + IoUs = [] + classes_true=np.unique(Yi) + for c in classes_true: + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU)) + IoUs.append(IoU) + mIoU = np.mean(IoUs) + print("_________________") + print("Mean IoU: {:4.3f}".format(mIoU)) + return mIoU +def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes): + c = 0 + n = os.listdir(img_folder) #List of training images + random.shuffle(n) + while True: + img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') + mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') + + for i in range(c, c+batch_size): #initially from 0 to 16, c = 0. + #print(img_folder+'/'+n[i]) + filename=n[i].split('.')[0] + train_img = cv2.imread(img_folder+'/'+n[i])/255. + train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize + + img[i-c] = train_img #add to array - img[0], img[1], and so on. 
+ train_mask = cv2.imread(mask_folder+'/'+filename+'.png') + #print(mask_folder+'/'+filename+'.png') + #print(train_mask.shape) + train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) + #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] + + mask[i-c] = train_mask + + c+=batch_size + if(c+batch_size>=len(os.listdir(img_folder))): + c=0 + random.shuffle(n) + yield img, mask + +def otsu_copy(img): + img_r=np.zeros(img.shape) + img1=img[:,:,0] + img2=img[:,:,1] + img3=img[:,:,2] + _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + img_r[:,:,0]=threshold1 + img_r[:,:,1]=threshold1 + img_r[:,:,2]=threshold1 + return img_r + +def rotation_90(img): + img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) + img_rot[:,:,0]=img[:,:,0].T + img_rot[:,:,1]=img[:,:,1].T + img_rot[:,:,2]=img[:,:,2].T + return img_rot + +def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer): + + + img_h=img.shape[0] + img_w=img.shape[1] + + nxf=img_w/float(width) + nyf=img_h/float(height) + + if nxf>int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width + index_x_u=(i+1)*width + + index_y_d=j*height + index_y_u=(j+1)*height + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + return indexer + + + +def 
get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler): + + + img_h=img.shape[0] + img_w=img.shape[1] + + height_scale=int(height*scaler) + width_scale=int(width*scaler) + + + nxf=img_w/float(width_scale) + nyf=img_h/float(height_scale) + + if nxf>int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width_scale + index_x_u=(i+1)*width_scale + + index_y_d=j*height_scale + index_y_u=(j+1)*height_scale + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width_scale + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height_scale + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + img_patch=resize_image(img_patch,height,width) + label_patch=resize_image(label_patch,height,width) + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + + return indexer + + + +def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + augmentation=False,patches=False): + + imgs_cv_train=np.array(os.listdir(dir_img)) + segs_cv_train=np.array(os.listdir(dir_seg)) + + indexer=0 + for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): + img_name=im.split('.')[0] + + if not patches: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) + indexer+=1 + + if augmentation: + if rotation: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + rotation_90( 
resize_image(cv2.imread(dir_img+'/'+im), + input_height,input_width) ) ) + + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', + rotation_90 ( resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width) ) ) + indexer+=1 + + if flip_aug: + for f_i in flip_index: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , + resize_image(cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png'),f_i),input_height,input_width) ) + indexer+=1 + + if blur_aug: + for blur_i in blur_k: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + (resize_image(bluring(cv2.imread(dir_img+'/'+im),blur_i),input_height,input_width) ) ) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , + resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width) ) + indexer+=1 + + + if binarization: + cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', + resize_image(otsu_copy( cv2.imread(dir_img+'/'+im)),input_height,input_width )) + + cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', + resize_image( cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width )) + indexer+=1 + + + + + + + if patches: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + if augmentation: + + if rotation: + + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + rotation_90( cv2.imread(dir_img+'/'+im) ), + rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ), + input_height,input_width,indexer=indexer) + if flip_aug: + for f_i in flip_index: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + cv2.flip( cv2.imread(dir_img+'/'+im) , f_i), + cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i), + 
input_height,input_width,indexer=indexer) + if blur_aug: + for blur_i in blur_k: + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + bluring( cv2.imread(dir_img+'/'+im) , blur_i), + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + + if scaling: + for sc_ind in scales: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im) , + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer,scaler=sc_ind) + if binarization: + + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)), + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) + + + + if scaling_bluring: + for sc_ind in scales: + for blur_i in blur_k: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + bluring( cv2.imread(dir_img+'/'+im) , blur_i) , + cv2.imread(dir_seg+'/'+img_name+'.png') , + input_height,input_width,indexer=indexer,scaler=sc_ind) + + if scaling_binarization: + for sc_ind in scales: + indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)) , + cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer,scaler=sc_ind) + + + + + + From 8084e136ba67513caa4e5309be70caff2b75fbea Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 10 Dec 2019 11:57:37 +0100 Subject: [PATCH 012/123] Update README --- train/README.md | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/train/README.md b/train/README.md index c4dc27e..16e5dce 100644 --- a/train/README.md +++ b/train/README.md @@ -4,16 +4,21 @@ # Ground truth format - Lables for each pixel is identified by a number . So if you have a binary case n_classes should be set to 2 and + Lables for each pixel is identified by a number . 
So if you have a + binary case n_classes should be set to 2 and labels should be 0 and 1 for each class and pixel. - In the case of multiclass just set n_classes to the number of classes you have and the try to produce the labels + In the case of multiclass just set n_classes to the number of classes + you have and the try to produce the labels by pixels set from 0 , 1 ,2 .., n_classes-1. The labels format should be png. If you have an image label for binary case it should look like this: - Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] - this means that you have an image by 3*4*3 and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], + [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] , + [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] + This means that you have an image by 3*4*3 and pixel[0,0] belongs + to class 1 and pixel[0,1] to class 0. # Training , evaluation and output train and evaluation folder should have subfolder of images and labels. @@ -21,6 +26,11 @@ # Patches - if you want to train your model with patches, the height and width of patches should be defined and also number of + if you want to train your model with patches, the height and width of + patches should be defined and also number of batchs (how many patches should be seen by model by each iteration). - In the case that model should see the image once, like page extraction, the patches should be set to false. \ No newline at end of file + In the case that model should see the image once, like page extraction, + the patches should be set to false. +# Pretrained encoder +Download weights from this limk and add it to pretrained_model folder. 
+https://file.spk-berlin.de:8443/pretrained_encoder/ From 4229ad92d7460ed9fdc63a2837527586fde18de3 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 10 Dec 2019 11:58:02 +0100 Subject: [PATCH 013/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 16e5dce..3ba90a1 100644 --- a/train/README.md +++ b/train/README.md @@ -32,5 +32,5 @@ In the case that model should see the image once, like page extraction, the patches should be set to false. # Pretrained encoder -Download weights from this limk and add it to pretrained_model folder. +Download weights from this link and add it to pretrained_model folder. https://file.spk-berlin.de:8443/pretrained_encoder/ From b5f9b9c54ad4ad746ab93bc7f81652f9158d75e5 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 10 Dec 2019 14:01:55 +0100 Subject: [PATCH 014/123] Update main.py --- train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/train.py b/train/train.py index 07c7418..baeb847 100644 --- a/train/train.py +++ b/train/train.py @@ -169,7 +169,7 @@ def run(n_classes,n_epochs,input_height, model.fit_generator( train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch), + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, validation_data=val_gen, validation_steps=1, epochs=n_epochs) From df536d62c04825e05ea5aceb6067616db3b357a8 Mon Sep 17 00:00:00 2001 From: Clemens Neudecker <952378+cneud@users.noreply.github.com> Date: Tue, 10 Dec 2019 16:39:41 +0100 Subject: [PATCH 015/123] Add LICENSE --- train/LICENSE | 201 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 train/LICENSE diff --git a/train/LICENSE b/train/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/train/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS 
AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From ad1360b179e0f4c39882bdd119e1760c7747db4d Mon Sep 17 00:00:00 2001 From: Clemens Neudecker <952378+cneud@users.noreply.github.com> Date: Wed, 15 Jan 2020 19:37:27 +0100 Subject: [PATCH 016/123] Update README.md --- train/README.md | 65 +++++++++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/train/README.md b/train/README.md index 3ba90a1..4c49f39 100644 --- a/train/README.md +++ b/train/README.md @@ -1,36 +1,47 @@ -# Train - just run: python train.py with config_params.json +# Pixelwise Segmentation +> Pixelwise segmentation for document images + +## Introduction +This repository contains the source code for training an encoder model for document image segmentation. 
+ +## Installation +Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pixelwise_segmentation.git` or download and unpack the [ZIP](https://github.com/qurator-spk/sbb_pixelwise_segmentation/archive/master.zip). + +## Usage + +### Train +To train a model, run: ``python train.py with config_params.json`` + +### Ground truth format +Lables for each pixel are identified by a number. So if you have a +binary case, ``n_classes`` should be set to ``2`` and labels should +be ``0`` and ``1`` for each class and pixel. + +In the case of multiclass, just set ``n_classes`` to the number of classes +you have and the try to produce the labels by pixels set from ``0 , 1 ,2 .., n_classes-1``. +The labels format should be png. - -# Ground truth format - - Lables for each pixel is identified by a number . So if you have a - binary case n_classes should be set to 2 and - labels should be 0 and 1 for each class and pixel. - In the case of multiclass just set n_classes to the number of classes - you have and the try to produce the labels - by pixels set from 0 , 1 ,2 .., n_classes-1. - The labels format should be png. - - If you have an image label for binary case it should look like this: +If you have an image label for a binary case it should look like this: Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] , [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] - This means that you have an image by 3*4*3 and pixel[0,0] belongs - to class 1 and pixel[0,1] to class 0. -# Training , evaluation and output - train and evaluation folder should have subfolder of images and labels. - And output folder should be empty folder which the output model will be written there. + This means that you have an image by `3*4*3` and `pixel[0,0]` belongs + to class `1` and `pixel[0,1]` belongs to class `0`. + +### Training , evaluation and output +The train and evaluation folders should contain subfolders of images and labels. 
+The output folder should be an empty folder where the output model will be written to. # Patches +If you want to train your model with patches, the height and width of +the patches should be defined and also the number of batches (how many patches +should be seen by the model in each iteration). + +In the case that the model should see the image once, like page extraction, +patches should be set to ``false``. - if you want to train your model with patches, the height and width of - patches should be defined and also number of - batchs (how many patches should be seen by model by each iteration). - In the case that model should see the image once, like page extraction, - the patches should be set to false. -# Pretrained encoder -Download weights from this link and add it to pretrained_model folder. -https://file.spk-berlin.de:8443/pretrained_encoder/ +### Pretrained encoder +Download our pretrained weights and add them to a ``pretrained_model`` folder: +~~https://file.spk-berlin.de:8443/pretrained_encoder/~~ From 66d7138343edc9fe3d7d918198a1f20b4112e42b Mon Sep 17 00:00:00 2001 From: Clemens Neudecker <952378+cneud@users.noreply.github.com> Date: Wed, 15 Jan 2020 19:43:31 +0100 Subject: [PATCH 017/123] Update README.md --- train/README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train/README.md b/train/README.md index 4c49f39..18495a5 100644 --- a/train/README.md +++ b/train/README.md @@ -7,6 +7,9 @@ This repository contains the source code for training an encoder model for docum ## Installation Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pixelwise_segmentation.git` or download and unpack the [ZIP](https://github.com/qurator-spk/sbb_pixelwise_segmentation/archive/master.zip). 
+### Pretrained encoder +Download our pretrained weights and add them to a ``pretrained_model`` folder: +~~https://file.spk-berlin.de:8443/pretrained_encoder/~~ ## Usage ### Train @@ -34,7 +37,7 @@ If you have an image label for a binary case it should look like this: The train and evaluation folders should contain subfolders of images and labels. The output folder should be an empty folder where the output model will be written to. -# Patches +### Patches If you want to train your model with patches, the height and width of the patches should be defined and also the number of batches (how many patches should be seen by the model in each iteration). @@ -42,6 +45,4 @@ should be seen by the model in each iteration). In the case that the model should see the image once, like page extraction, patches should be set to ``false``. -### Pretrained encoder -Download our pretrained weights and add them to a ``pretrained_model`` folder: -~~https://file.spk-berlin.de:8443/pretrained_encoder/~~ + From 4e216475dca544515488071f035cde639d053584 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 16 Jan 2020 15:53:39 +0100 Subject: [PATCH 018/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 18495a5..ede05dd 100644 --- a/train/README.md +++ b/train/README.md @@ -9,7 +9,7 @@ Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pi ### Pretrained encoder Download our pretrained weights and add them to a ``pretrained_model`` folder: -~~https://file.spk-berlin.de:8443/pretrained_encoder/~~ +https://qurator-data.de/pretrained_encoder/ ## Usage ### Train From b54285b19684e6a6b86a52448dc9afd4a38e95ea Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 16 Jan 2020 16:05:06 +0100 Subject: [PATCH 019/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 
ede05dd..d0d26d6 100644 --- a/train/README.md +++ b/train/README.md @@ -9,7 +9,7 @@ Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pi ### Pretrained encoder Download our pretrained weights and add them to a ``pretrained_model`` folder: -https://qurator-data.de/pretrained_encoder/ +https://qurator-data.de/sbb_pixelwise_segmentation/pretrained_encoder/ ## Usage ### Train From 070c2e046259441b712d11be21eb26c6db191b71 Mon Sep 17 00:00:00 2001 From: vahid Date: Tue, 22 Jun 2021 14:20:51 -0400 Subject: [PATCH 020/123] first updates, padding, rotations --- train/config_params.json | 22 ++-- train/train.py | 183 ++++++++++++++------------- train/utils.py | 265 +++++++++++++++++++++++++++++++-------- 3 files changed, 319 insertions(+), 151 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index 5066444..d8f1ac5 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,24 +1,24 @@ { - "n_classes" : 2, - "n_epochs" : 2, + "n_classes" : 3, + "n_epochs" : 1, "input_height" : 448, - "input_width" : 896, + "input_width" : 672, "weight_decay" : 1e-6, - "n_batch" : 1, + "n_batch" : 2, "learning_rate": 1e-4, "patches" : true, "pretraining" : true, - "augmentation" : false, + "augmentation" : true, "flip_aug" : false, - "elastic_aug" : false, - "blur_aug" : false, + "blur_aug" : true, "scaling" : false, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, + "scaling_flip" : false, "rotation": false, - "weighted_loss": true, - "dir_train": "../train", - "dir_eval": "../eval", - "dir_output": "../output" + "rotation_not_90": false, + "dir_train": "/home/vahid/Documents/handwrittens_train/train", + "dir_eval": "/home/vahid/Documents/handwrittens_train/eval", + "dir_output": "/home/vahid/Documents/handwrittens_train/output" } diff --git a/train/train.py b/train/train.py index baeb847..c256d83 100644 --- a/train/train.py +++ b/train/train.py @@ -8,7 +8,7 @@ from sacred import 
Experiment from models import * from utils import * from metrics import * - +from keras.models import load_model def configuration(): keras.backend.clear_session() @@ -47,7 +47,6 @@ def config_params(): # extraction this should be set to false since model should see all image. augmentation=False flip_aug=False # Flip image (augmentation). - elastic_aug=False # Elastic transformation (augmentation). blur_aug=False # Blur patches of image (augmentation). scaling=False # Scaling of patches (augmentation) will be imposed if this set to true. binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied. @@ -55,110 +54,116 @@ def config_params(): dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels). dir_output=None # Directory of output where the model should be saved. pretraining=False # Set true to load pretrained weights of resnet50 encoder. - weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function. scaling_bluring=False - rotation: False scaling_binarization=False + scaling_flip=False + thetha=[10,-10] blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. - scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation. - flip_index=[0,1] # Flip image. Used for augmentation. + scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation. + flip_index=[0,1,-1] # Flip image. Used for augmentation. 
@ex.automain def run(n_classes,n_epochs,input_height, - input_width,weight_decay,weighted_loss, - n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization, + input_width,weight_decay, + n_batch,patches,augmentation,flip_aug + ,blur_aug,scaling, binarization, blur_k,scales,dir_train, scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, flip_index,dir_eval ,dir_output,pretraining,learning_rate): - dir_img,dir_seg=get_dirs_or_files(dir_train) - dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) + data_is_provided = False - # make first a directory in output for both training and evaluations in order to flow data from these directories. - dir_train_flowing=os.path.join(dir_output,'train') - dir_eval_flowing=os.path.join(dir_output,'eval') - - dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') - dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') - - dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') - dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') - - if os.path.isdir(dir_train_flowing): - os.system('rm -rf '+dir_train_flowing) - os.makedirs(dir_train_flowing) + if data_is_provided: + dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') + + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') + + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') + + configuration() + else: - os.makedirs(dir_train_flowing) + dir_img,dir_seg=get_dirs_or_files(dir_train) + dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) - if os.path.isdir(dir_eval_flowing): - os.system('rm -rf '+dir_eval_flowing) - os.makedirs(dir_eval_flowing) - else: - os.makedirs(dir_eval_flowing) + # make first a directory in output for both training and evaluations in order to flow data from these directories. 
+ dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') - - os.mkdir(dir_flow_train_imgs) - os.mkdir(dir_flow_train_labels) - - os.mkdir(dir_flow_eval_imgs) - os.mkdir(dir_flow_eval_labels) - - - - #set the gpu configuration - configuration() - - - #writing patches into a sub-folder in order to be flowed from directory. - provide_patches(dir_img,dir_seg,dir_flow_train_imgs, - dir_flow_train_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - augmentation=augmentation,patches=patches) - - provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, - dir_flow_eval_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - augmentation=False,patches=patches) + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images/') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels/') - if weighted_loss: - weights=np.zeros(n_classes) - for obj in os.listdir(dir_seg): - label_obj=cv2.imread(dir_seg+'/'+obj) - label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) - weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images/') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels/') + + if os.path.isdir(dir_train_flowing): + os.system('rm -rf '+dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) + + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf '+dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) - weights=1.00/weights + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) - weights=weights/float(np.sum(weights)) - weights=weights/float(np.min(weights)) - weights=weights/float(np.sum(weights)) + os.mkdir(dir_flow_eval_imgs) + 
os.mkdir(dir_flow_eval_labels) + + + #set the gpu configuration + configuration() - - + + #writing patches into a sub-folder in order to be flowed from directory. + provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, + augmentation=augmentation,patches=patches) - #get our model. - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, + dir_flow_eval_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, + augmentation=False,patches=patches) + + + continue_train = False + + if continue_train: + model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5' + model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) + index_start = 1 + else: + #get our model. + index_start = 0 + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) #if you want to see the model structure just uncomment model summary. 
#model.summary() - if not weighted_loss: - model.compile(loss='categorical_crossentropy', - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - if weighted_loss: - model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - - mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', - save_weights_only=True, period=1) - + + #model.compile(loss='categorical_crossentropy', + #optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + + model.compile(loss=soft_dice_loss, + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) #generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, @@ -166,20 +171,20 @@ def run(n_classes,n_epochs,input_height, val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, input_height=input_height, input_width=input_width,n_classes=n_classes ) + for i in range(index_start, n_epochs+index_start): + model.fit_generator( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, + validation_data=val_gen, + validation_steps=1, + epochs=1) + model.save(dir_output+'/'+'model_'+str(i)+'.h5') - model.fit_generator( - train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, - validation_data=val_gen, - validation_steps=1, - epochs=n_epochs) - - os.system('rm -rf '+dir_train_flowing) os.system('rm -rf '+dir_eval_flowing) - model.save(dir_output+'/'+'model'+'.h5') + #model.save(dir_output+'/'+'model'+'.h5') diff --git a/train/utils.py b/train/utils.py index afdc9e5..a77444e 100644 --- a/train/utils.py +++ b/train/utils.py @@ -6,7 +6,8 @@ from scipy.ndimage.interpolation import map_coordinates from scipy.ndimage.filters import gaussian_filter import random from tqdm import tqdm - +import imutils +import math @@ -19,6 +20,79 @@ def bluring(img_in,kind): img_blur=cv2.blur(img_in,(5,5)) return img_blur +def elastic_transform(image, alpha, 
sigma,seedj, random_state=None): + + """Elastic deformation of images as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + if random_state is None: + random_state = np.random.RandomState(seedj) + + shape = image.shape + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dz = np.zeros_like(dx) + + x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)), np.reshape(z, (-1, 1)) + + distored_image = map_coordinates(image, indices, order=1, mode='reflect') + return distored_image.reshape(image.shape) + +def rotation_90(img): + img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) + img_rot[:,:,0]=img[:,:,0].T + img_rot[:,:,1]=img[:,:,1].T + img_rot[:,:,2]=img[:,:,2].T + return img_rot + +def rotatedRectWithMaxArea(w, h, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + radians), computes the width and height of the largest possible + axis-aligned rectangle (maximal area) within the rotated rectangle. 
+ """ + if w <= 0 or h <= 0: + return 0,0 + + width_is_longer = w >= h + side_long, side_short = (w,h) if width_is_longer else (h,w) + + # since the solutions for angle, -angle and 180-angle are all the same, + # if suffices to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5*side_short + wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a*cos_a - sin_a*sin_a + wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a + + return wr,hr + +def rotate_max_area(image,rotated, rotated_label,angle): + """ image: cv2 image matrix object + angle: in degree + """ + wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], + math.radians(angle)) + h, w, _ = rotated.shape + y1 = h//2 - int(hr/2) + y2 = y1 + int(hr) + x1 = w//2 - int(wr/2) + x2 = x1 + int(wr) + return rotated[y1:y2, x1:x2],rotated_label[y1:y2, x1:x2] +def rotation_not_90_func(img,label,thetha): + rotated=imutils.rotate(img,thetha) + rotated_label=imutils.rotate(label,thetha) + return rotate_max_area(img, rotated,rotated_label,thetha) + def color_images(seg, n_classes): ann_u=range(n_classes) if len(np.shape(seg))==3: @@ -65,7 +139,7 @@ def IoU(Yi,y_predi): return mIoU def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes): c = 0 - n = os.listdir(img_folder) #List of training images + n = [f for f in os.listdir(img_folder) if not f.startswith('.')]# os.listdir(img_folder) #List of training images random.shuffle(n) while True: img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') @@ -73,18 +147,26 @@ def data_gen(img_folder, mask_folder, batch_size,input_height, 
input_width,n_cla for i in range(c, c+batch_size): #initially from 0 to 16, c = 0. #print(img_folder+'/'+n[i]) - filename=n[i].split('.')[0] - train_img = cv2.imread(img_folder+'/'+n[i])/255. - train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize - - img[i-c] = train_img #add to array - img[0], img[1], and so on. - train_mask = cv2.imread(mask_folder+'/'+filename+'.png') - #print(mask_folder+'/'+filename+'.png') - #print(train_mask.shape) - train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) - #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] - - mask[i-c] = train_mask + + try: + filename=n[i].split('.')[0] + + train_img = cv2.imread(img_folder+'/'+n[i])/255. + train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize + + img[i-c] = train_img #add to array - img[0], img[1], and so on. 
+ train_mask = cv2.imread(mask_folder+'/'+filename+'.png') + #print(mask_folder+'/'+filename+'.png') + #print(train_mask.shape) + train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) + #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] + + mask[i-c] = train_mask + except: + img[i-c] = np.ones((input_height, input_width, 3)).astype('float') + mask[i-c] = np.zeros((input_height, input_width, n_classes)).astype('float') + + c+=batch_size if(c+batch_size>=len(os.listdir(img_folder))): @@ -104,16 +186,10 @@ def otsu_copy(img): img_r[:,:,1]=threshold1 img_r[:,:,2]=threshold1 return img_r - -def rotation_90(img): - img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) - img_rot[:,:,0]=img[:,:,0].T - img_rot[:,:,1]=img[:,:,1].T - img_rot[:,:,2]=img[:,:,2].T - return img_rot - def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer): + if img.shape[0]int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width_scale + index_x_u=(i+1)*width_scale + + index_y_d=j*height_scale + index_y_u=(j+1)*height_scale + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width_scale + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height_scale + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + #img_patch=resize_image(img_patch,height,width) + #label_patch=resize_image(label_patch,height,width) + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + + return indexer def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, @@ -211,6 +366,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, input_height,input_width,blur_k,blur_aug, 
flip_aug,binarization,scaling,scales,flip_index, scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, augmentation=False,patches=False): imgs_cv_train=np.array(os.listdir(dir_img)) @@ -218,25 +374,15 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, indexer=0 for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): + #print(im, seg_i) img_name=im.split('.')[0] - + print(img_name,'img_name') if not patches: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) indexer+=1 if augmentation: - if rotation: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', - rotation_90( resize_image(cv2.imread(dir_img+'/'+im), - input_height,input_width) ) ) - - - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', - rotation_90 ( resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width) ) ) - indexer+=1 - if flip_aug: for f_i in flip_index: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', @@ -270,10 +416,10 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, if patches: - + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) + cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) if augmentation: @@ -284,29 +430,37 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, rotation_90( cv2.imread(dir_img+'/'+im) ), rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ), input_height,input_width,indexer=indexer) + + if rotation_not_90: + + for thetha_i in thetha: + 
img_max_rotated,label_max_rotated=rotation_not_90_func(cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),thetha_i) + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + img_max_rotated, + label_max_rotated, + input_height,input_width,indexer=indexer) if flip_aug: for f_i in flip_index: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, cv2.flip( cv2.imread(dir_img+'/'+im) , f_i), cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i), input_height,input_width,indexer=indexer) if blur_aug: for blur_i in blur_k: + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, bluring( cv2.imread(dir_img+'/'+im) , blur_i), cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) - + input_height,input_width,indexer=indexer) + if scaling: for sc_ind in scales: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im) , - cv2.imread(dir_seg+'/'+img_name+'.png'), + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im) , + cv2.imread(dir_seg+'/'+img_name+'.png'), input_height,input_width,indexer=indexer,scaler=sc_ind) if binarization: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, otsu_copy( cv2.imread(dir_img+'/'+im)), cv2.imread(dir_seg+'/'+img_name+'.png'), @@ -317,17 +471,26 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, if scaling_bluring: for sc_ind in scales: for blur_i in blur_k: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, bluring( cv2.imread(dir_img+'/'+im) , blur_i) , cv2.imread(dir_seg+'/'+img_name+'.png') , input_height,input_width,indexer=indexer,scaler=sc_ind) if scaling_binarization: for sc_ind in scales: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, - otsu_copy( cv2.imread(dir_img+'/'+im)) , - 
cv2.imread(dir_seg+'/'+img_name+'.png'), + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)) , + cv2.imread(dir_seg+'/'+img_name+'.png'), input_height,input_width,indexer=indexer,scaler=sc_ind) + + if scaling_flip: + for sc_ind in scales: + for f_i in flip_index: + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + cv2.flip( cv2.imread(dir_img+'/'+im) , f_i) , + cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i) , + input_height,input_width,indexer=indexer,scaler=sc_ind) + From 8884b90f052c9d29d10dcce7f8636d41437181b8 Mon Sep 17 00:00:00 2001 From: vahid Date: Tue, 22 Jun 2021 18:47:59 -0400 Subject: [PATCH 021/123] continue training, losses and etc --- train/config_params.json | 14 +++++--- train/train.py | 77 ++++++++++++++++++++++++++++++---------- train/utils.py | 2 -- 3 files changed, 69 insertions(+), 24 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index d8f1ac5..eaa50e1 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,6 +1,6 @@ { "n_classes" : 3, - "n_epochs" : 1, + "n_epochs" : 2, "input_height" : 448, "input_width" : 672, "weight_decay" : 1e-6, @@ -8,16 +8,22 @@ "learning_rate": 1e-4, "patches" : true, "pretraining" : true, - "augmentation" : true, + "augmentation" : false, "flip_aug" : false, - "blur_aug" : true, - "scaling" : false, + "blur_aug" : false, + "scaling" : true, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, "rotation": false, "rotation_not_90": false, + "continue_training": false, + "index_start": 0, + "dir_of_start_model": " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, "dir_train": "/home/vahid/Documents/handwrittens_train/train", "dir_eval": "/home/vahid/Documents/handwrittens_train/eval", "dir_output": "/home/vahid/Documents/handwrittens_train/output" diff --git a/train/train.py 
b/train/train.py index c256d83..0cc5ef3 100644 --- a/train/train.py +++ b/train/train.py @@ -9,6 +9,7 @@ from models import * from utils import * from metrics import * from keras.models import load_model +from tqdm import tqdm def configuration(): keras.backend.clear_session() @@ -61,19 +62,24 @@ def config_params(): blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation. flip_index=[0,1,-1] # Flip image. Used for augmentation. - + continue_training = False # If + index_start = 0 + dir_of_start_model = '' + is_loss_soft_dice = False + weighted_loss = False + data_is_provided = False @ex.automain def run(n_classes,n_epochs,input_height, - input_width,weight_decay, + input_width,weight_decay,weighted_loss, + index_start,dir_of_start_model,is_loss_soft_dice, n_batch,patches,augmentation,flip_aug ,blur_aug,scaling, binarization, - blur_k,scales,dir_train, + blur_k,scales,dir_train,data_is_provided, scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip, + rotation_not_90,thetha,scaling_flip,continue_training, flip_index,dir_eval ,dir_output,pretraining,learning_rate): - data_is_provided = False if data_is_provided: dir_train_flowing=os.path.join(dir_output,'train') @@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height, augmentation=False,patches=patches) - continue_train = False + + if weighted_loss: + weights=np.zeros(n_classes) + if data_is_provided: + for obj in os.listdir(dir_flow_train_labels): + try: + label_obj=cv2.imread(dir_flow_train_labels+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + else: + + for obj in os.listdir(dir_seg): + try: + label_obj=cv2.imread(dir_seg+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + 
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + - if continue_train: - model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5' - model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) - index_start = 1 + weights=1.00/weights + + weights=weights/float(np.sum(weights)) + weights=weights/float(np.min(weights)) + weights=weights/float(np.sum(weights)) + + + + if continue_training: + if is_loss_soft_dice: + model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model (dir_of_start_model, compile = True) else: #get our model. index_start = 0 @@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height, #model.summary() + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - #model.compile(loss='categorical_crossentropy', - #optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - - model.compile(loss=soft_dice_loss, - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) #generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, @@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height, val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, input_height=input_height, input_width=input_width,n_classes=n_classes ) - for i in range(index_start, 
n_epochs+index_start): + for i in tqdm(range(index_start, n_epochs+index_start)): model.fit_generator( train_gen, steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, @@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height, model.save(dir_output+'/'+'model_'+str(i)+'.h5') - os.system('rm -rf '+dir_train_flowing) - os.system('rm -rf '+dir_eval_flowing) + #os.system('rm -rf '+dir_train_flowing) + #os.system('rm -rf '+dir_eval_flowing) #model.save(dir_output+'/'+'model'+'.h5') diff --git a/train/utils.py b/train/utils.py index a77444e..19ab46e 100644 --- a/train/utils.py +++ b/train/utils.py @@ -374,9 +374,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, indexer=0 for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): - #print(im, seg_i) img_name=im.split('.')[0] - print(img_name,'img_name') if not patches: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) From 2d9ba854674db7169c3aceb4fca562b96bbed1f1 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 23 Jun 2021 07:25:49 -0400 Subject: [PATCH 022/123] Update README.md --- train/README.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/train/README.md b/train/README.md index d0d26d6..87a59ce 100644 --- a/train/README.md +++ b/train/README.md @@ -23,14 +23,16 @@ be ``0`` and ``1`` for each class and pixel. In the case of multiclass, just set ``n_classes`` to the number of classes you have and the try to produce the labels by pixels set from ``0 , 1 ,2 .., n_classes-1``. The labels format should be png. +Our lables are 3 channel png images but only information of first channel is used. 
+If you have an image label with height and width of 10, for a binary case the first channel should look like this: -If you have an image label for a binary case it should look like this: + Label: [ [1, 0, 0, 1, 1, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ..., + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ] - Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], - [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] , - [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] - - This means that you have an image by `3*4*3` and `pixel[0,0]` belongs + This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. ### Training , evaluation and output From 15407393e20a5c66556a0ab8e364f2206156ad27 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 23 Jun 2021 07:55:36 -0400 Subject: [PATCH 023/123] Update README.md --- train/README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train/README.md b/train/README.md index 87a59ce..464a9a4 100644 --- a/train/README.md +++ b/train/README.md @@ -39,12 +39,15 @@ If you have an image label with height and width of 10, for a binary case the fi The train and evaluation folders should contain subfolders of images and labels. The output folder should be an empty folder where the output model will be written to. -### Patches -If you want to train your model with patches, the height and width of -the patches should be defined and also the number of batches (how many patches -should be seen by the model in each iteration). - -In the case that the model should see the image once, like page extraction, -patches should be set to ``false``. +### Parameter configuration +* patches: If you want to break input images into smaller patches (input size of the model) you need to set this parameter to ``true``. In the case that the model should see the image once, like page extraction, patches should be set to ``false``. 
+* n_batch: Number of batches at each iteration. +* n_classes: Number of classes. In the case of binary classification this should be 2. +* n_epochs: Number of epochs. +* input_height: This indicates the height of model's input. +* input_width: This indicates the width of model's input. +* weight_decay: Weight decay of l2 regularization of model layers. +* augmentation: If you want to apply any kind of augmentation this parameter should first set to ``true``. +* flip_aug: If ``true``, different types of filp will applied on image. Type of flips is given by "flip_index" in train.py file. From 491cdbf9342ffeebabe088b60371c2f18dd8cfaf Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 23 Jun 2021 08:21:12 -0400 Subject: [PATCH 024/123] Update README.md --- train/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 464a9a4..af8595f 100644 --- a/train/README.md +++ b/train/README.md @@ -48,6 +48,18 @@ The output folder should be an empty folder where the output model will be writt * input_width: This indicates the width of model's input. * weight_decay: Weight decay of l2 regularization of model layers. * augmentation: If you want to apply any kind of augmentation this parameter should first set to ``true``. -* flip_aug: If ``true``, different types of filp will applied on image. Type of flips is given by "flip_index" in train.py file. +* flip_aug: If ``true``, different types of filp will be applied on image. Type of flips is given with "flip_index" in train.py file. +* blur_aug: If ``true``, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" in train.py file. +* scaling: If ``true``, scaling will be applied on image. Scale of scaling is given with "scales" in train.py file. +* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rothation angles are given with "thetha" in train.py file. 
+* rotation: If ``true``, 90 degree rotation will be applied on image. +* binarization: If ``true``,Otsu thresholding will be applied to augment the input data with binarized images. +* scaling_bluring: If ``true``, combination of scaling and blurring will be applied on image. +* scaling_binarization: If ``true``, combination of scaling and binarization will be applied on image. +* scaling_flip: If ``true``, combination of scaling and flip will be applied on image. +* continue_training: If ``true``, it means that you have already trained a model and you would like to continue the training. So it is needed to provide the dir of trained model with "dir_of_start_model" and index for naming the models. For example if you have already trained for 3 epochs then your last index is 2 and if you want to continue from model_1.h5, you can set "index_start" to 3 to start naming model with index 3. +* weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` +* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data should be in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". +* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resize and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. 
From 76c75d1365ee31e5637c763c89e664e7bbc45b0d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 23 Jun 2021 08:22:03 -0400 Subject: [PATCH 025/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index af8595f..c38aea1 100644 --- a/train/README.md +++ b/train/README.md @@ -59,7 +59,7 @@ The output folder should be an empty folder where the output model will be writt * scaling_flip: If ``true``, combination of scaling and flip will be applied on image. * continue_training: If ``true``, it means that you have already trained a model and you would like to continue the training. So it is needed to provide the dir of trained model with "dir_of_start_model" and index for naming the models. For example if you have already trained for 3 epochs then your last index is 2 and if you want to continue from model_1.h5, you can set "index_start" to 3 to start naming model with index 3. * weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` -* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data should be in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". +* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". * dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. 
Namely they are not prepared (not resize and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. From 310a709ac7d2b1632580b53d6b4b3c127230e808 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 23 Jun 2021 08:23:20 -0400 Subject: [PATCH 026/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index c38aea1..5272def 100644 --- a/train/README.md +++ b/train/README.md @@ -60,6 +60,6 @@ The output folder should be an empty folder where the output model will be writt * continue_training: If ``true``, it means that you have already trained a model and you would like to continue the training. So it is needed to provide the dir of trained model with "dir_of_start_model" and index for naming the models. For example if you have already trained for 3 epochs then your last index is 2 and if you want to continue from model_1.h5, you can set "index_start" to 3 to start naming model with index 3. * weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` * data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". -* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resize and not augmented) yet for training the model. 
When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. +* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. From b1c8bdf10624e3580c46105c2f323a0bc14b8178 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Jun 2021 07:19:32 -0400 Subject: [PATCH 027/123] Update README.md --- train/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train/README.md b/train/README.md index 5272def..363ba21 100644 --- a/train/README.md +++ b/train/README.md @@ -34,6 +34,8 @@ If you have an image label with height and width of 10, for a binary case the fi This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. + + A small sample of training data for binarization experiment can be found here https://qurator-data.de/binarization_training_data_sample/ which contains images and lables folders. ### Training , evaluation and output The train and evaluation folders should contain subfolders of images and labels. 
From 49853bb291ff048874c8d0d8a4683968211b9ac8 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Jun 2021 07:21:34 -0400 Subject: [PATCH 028/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 363ba21..529d7c7 100644 --- a/train/README.md +++ b/train/README.md @@ -35,7 +35,7 @@ If you have an image label with height and width of 10, for a binary case the fi This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. - A small sample of training data for binarization experiment can be found here https://qurator-data.de/binarization_training_data_sample/ which contains images and lables folders. + A small sample of training data for binarization experiment can be found here [Training data sample](https://qurator-data.de/binarization_training_data_sample/) which contains images and lables folders. ### Training , evaluation and output The train and evaluation folders should contain subfolders of images and labels. From 09c0d5e318e1115b99dc3c9635179851370b54fe Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Jun 2021 07:22:13 -0400 Subject: [PATCH 029/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 529d7c7..58f3eae 100644 --- a/train/README.md +++ b/train/README.md @@ -35,7 +35,7 @@ If you have an image label with height and width of 10, for a binary case the fi This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. - A small sample of training data for binarization experiment can be found here [Training data sample](https://qurator-data.de/binarization_training_data_sample/) which contains images and lables folders. 
+ A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/binarization_training_data_sample/) , which contains images and lables folders. ### Training , evaluation and output The train and evaluation folders should contain subfolders of images and labels. From bcc900be1732ac5c9a94d2d99e37673c745d96af Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Jun 2021 07:22:34 -0400 Subject: [PATCH 030/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 58f3eae..0f0eb55 100644 --- a/train/README.md +++ b/train/README.md @@ -35,7 +35,7 @@ If you have an image label with height and width of 10, for a binary case the fi This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. - A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/binarization_training_data_sample/) , which contains images and lables folders. + A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/binarization_training_data_sample/), which contains images and lables folders. ### Training , evaluation and output The train and evaluation folders should contain subfolders of images and labels. 
From 083f5ae881436fad4e3f0e5b2caac068fa7bcf54 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 14 Jul 2021 06:01:33 -0400 Subject: [PATCH 031/123] Update README.md --- train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index 0f0eb55..8acfa12 100644 --- a/train/README.md +++ b/train/README.md @@ -35,7 +35,7 @@ If you have an image label with height and width of 10, for a binary case the fi This means that you have an image by `10*10*3` and `pixel[0,0]` belongs to class `1` and `pixel[0,1]` belongs to class `0`. - A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/binarization_training_data_sample/), which contains images and lables folders. + A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/~vahid.rezanezhad/binarization_training_data_sample/), which contains images and lables folders. ### Training , evaluation and output The train and evaluation folders should contain subfolders of images and labels. 
From 5282caa3286f121f9195263d5419c3876c7d9b4f Mon Sep 17 00:00:00 2001 From: vahid Date: Mon, 22 Aug 2022 13:03:10 +0200 Subject: [PATCH 032/123] supposed to solve https://github.com/qurator-spk/sbb_binarization/issues/41 --- ..._model_load_pretrained_weights_and_save.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 train/build_model_load_pretrained_weights_and_save.py diff --git a/train/build_model_load_pretrained_weights_and_save.py b/train/build_model_load_pretrained_weights_and_save.py new file mode 100644 index 0000000..251e698 --- /dev/null +++ b/train/build_model_load_pretrained_weights_and_save.py @@ -0,0 +1,33 @@ +import os +import sys +import tensorflow as tf +import keras , warnings +from keras.optimizers import * +from sacred import Experiment +from models import * +from utils import * +from metrics import * + + + + +def configuration(): + gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) + session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) + + +if __name__=='__main__': + n_classes = 2 + input_height = 224 + input_width = 448 + weight_decay = 1e-6 + pretraining = False + dir_of_weights = 'model_bin_sbb_ens.h5' + + #configuration() + + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + model.load_weights(dir_of_weights) + model.save('./name_in_another_python_version.h5') + + From 57dae564b359f905f636bb4579aff12d7e336d36 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 4 Apr 2024 11:26:28 +0200 Subject: [PATCH 033/123] adjusting to tf2 --- ..._model_load_pretrained_weights_and_save.py | 4 ++-- train/metrics.py | 2 +- train/models.py | 8 +++---- train/train.py | 24 +++++++------------ 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/train/build_model_load_pretrained_weights_and_save.py b/train/build_model_load_pretrained_weights_and_save.py index 251e698..3b1a577 100644 --- 
a/train/build_model_load_pretrained_weights_and_save.py +++ b/train/build_model_load_pretrained_weights_and_save.py @@ -1,8 +1,8 @@ import os import sys import tensorflow as tf -import keras , warnings -from keras.optimizers import * +import warnings +from tensorflow.keras.optimizers import * from sacred import Experiment from models import * from utils import * diff --git a/train/metrics.py b/train/metrics.py index c63cc22..1768960 100644 --- a/train/metrics.py +++ b/train/metrics.py @@ -1,4 +1,4 @@ -from keras import backend as K +from tensorflow.keras import backend as K import tensorflow as tf import numpy as np diff --git a/train/models.py b/train/models.py index 7c806b4..40a21a1 100644 --- a/train/models.py +++ b/train/models.py @@ -1,7 +1,7 @@ -from keras.models import * -from keras.layers import * -from keras import layers -from keras.regularizers import l2 +from tensorflow.keras.models import * +from tensorflow.keras.layers import * +from tensorflow.keras import layers +from tensorflow.keras.regularizers import l2 resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING ='channels_last' diff --git a/train/train.py b/train/train.py index 0cc5ef3..142b79b 100644 --- a/train/train.py +++ b/train/train.py @@ -1,29 +1,21 @@ import os import sys import tensorflow as tf -from keras.backend.tensorflow_backend import set_session -import keras , warnings -from keras.optimizers import * +from tensorflow.compat.v1.keras.backend import set_session +import warnings +from tensorflow.keras.optimizers import * from sacred import Experiment from models import * from utils import * from metrics import * -from keras.models import load_model +from tensorflow.keras.models import load_model from tqdm import tqdm def configuration(): - keras.backend.clear_session() - tf.reset_default_graph() - warnings.filterwarnings('ignore') - - os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' - config = 
tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) - - + config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - config.gpu_options.per_process_gpu_memory_fraction=0.95#0.95 - config.gpu_options.visible_device_list="0" - set_session(tf.Session(config=config)) + session = tf.compat.v1.Session(config=config) + set_session(session) def get_dirs_or_files(input_data): if os.path.isdir(input_data): @@ -219,7 +211,7 @@ def run(n_classes,n_epochs,input_height, validation_data=val_gen, validation_steps=1, epochs=1) - model.save(dir_output+'/'+'model_'+str(i)+'.h5') + model.save(dir_output+'/'+'model_'+str(i)) #os.system('rm -rf '+dir_train_flowing) From ced1f851e267cf986d0e1dbf1bb63e15db31c823 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 4 Apr 2024 11:30:12 +0200 Subject: [PATCH 034/123] adding requirements --- train/requirements.txt | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 train/requirements.txt diff --git a/train/requirements.txt b/train/requirements.txt new file mode 100644 index 0000000..f804172 --- /dev/null +++ b/train/requirements.txt @@ -0,0 +1,7 @@ +tensorflow == 2.12.1 +sacred +opencv-python +seaborn +tqdm +imutils + From 45652294972f2ce7c8d1f473621901f322b9c4b6 Mon Sep 17 00:00:00 2001 From: cneud <952378+cneud@users.noreply.github.com> Date: Wed, 10 Apr 2024 20:03:02 +0200 Subject: [PATCH 035/123] use headless cv2 --- train/requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train/requirements.txt b/train/requirements.txt index f804172..cbe2d88 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -1,7 +1,6 @@ tensorflow == 2.12.1 sacred -opencv-python +opencv-python-headless seaborn tqdm imutils - From d0b039505956af90594d14a6535add8deeaa8583 Mon Sep 17 00:00:00 2001 From: cneud <952378+cneud@users.noreply.github.com> Date: Wed, 10 Apr 2024 20:26:26 +0200 Subject: [PATCH 036/123] add info on helpful tools (fix #14) --- train/README.md | 10 
++++++++++ 1 file changed, 10 insertions(+) diff --git a/train/README.md b/train/README.md index 8acfa12..89fa227 100644 --- a/train/README.md +++ b/train/README.md @@ -10,6 +10,16 @@ Either clone the repository via `git clone https://github.com/qurator-spk/sbb_pi ### Pretrained encoder Download our pretrained weights and add them to a ``pretrained_model`` folder: https://qurator-data.de/sbb_pixelwise_segmentation/pretrained_encoder/ + +### Helpful tools +* [`pagexml2img`](https://github.com/qurator-spk/page2img) +> Tool to extract 2-D or 3-D RGB images from PAGE-XML data. In the former case, the output will be 1 2-D image array which each class has filled with a pixel value. In the case of a 3-D RGB image, +each class will be defined with a RGB value and beside images, a text file of classes will also be produced. +* [`cocoSegmentationToPng`](https://github.com/nightrome/cocostuffapi/blob/17acf33aef3c6cc2d6aca46dcf084266c2778cf0/PythonAPI/pycocotools/cocostuffhelper.py#L130) +> Convert COCO GT or results for a single image to a segmentation map and write it to disk. +* [`ocrd-segment-extract-pages`](https://github.com/OCR-D/ocrd_segment/blob/master/ocrd_segment/extract_pages.py) +> Extract region classes and their colours in mask (pseg) images. Allows the color map as free dict parameter, and comes with a default that mimics PageViewer's coloring for quick debugging; it also warns when regions do overlap. 
+ ## Usage ### Train From 39aa88669b98f364b33520ce45ff42b126be686c Mon Sep 17 00:00:00 2001 From: cneud <952378+cneud@users.noreply.github.com> Date: Wed, 10 Apr 2024 21:40:23 +0200 Subject: [PATCH 037/123] update parameter config docs (fix #11) --- train/train.py | 57 +++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/train/train.py b/train/train.py index 142b79b..9f833e0 100644 --- a/train/train.py +++ b/train/train.py @@ -29,37 +29,36 @@ ex = Experiment() @ex.config def config_params(): - n_classes=None # Number of classes. If your case study is binary case the set it to 2 and otherwise give your number of cases. - n_epochs=1 - input_height=224*1 - input_width=224*1 + n_classes=None # Number of classes. In the case of binary classification this should be 2. + n_epochs=1 # Number of epochs. + input_height=224*1 # Height of model's input in pixels. + input_width=224*1 # Width of model's input in pixels. weight_decay=1e-6 # Weight decay of l2 regularization of model layers. n_batch=1 # Number of batches at each iteration. - learning_rate=1e-4 - patches=False # Make patches of image in order to use all information of image. In the case of page - # extraction this should be set to false since model should see all image. - augmentation=False - flip_aug=False # Flip image (augmentation). - blur_aug=False # Blur patches of image (augmentation). - scaling=False # Scaling of patches (augmentation) will be imposed if this set to true. - binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied. - dir_train=None # Directory of training dataset (sub-folders should be named images and labels). - dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels). - dir_output=None # Directory of output where the model should be saved. 
- pretraining=False # Set true to load pretrained weights of resnet50 encoder. - scaling_bluring=False - scaling_binarization=False - scaling_flip=False - thetha=[10,-10] - blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. - scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation. - flip_index=[0,1,-1] # Flip image. Used for augmentation. - continue_training = False # If - index_start = 0 - dir_of_start_model = '' - is_loss_soft_dice = False - weighted_loss = False - data_is_provided = False + learning_rate=1e-4 # Set the learning rate. + patches=False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. + augmentation=False # To apply any kind of augmentation, this parameter must be set to true. + flip_aug=False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py. + blur_aug=False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py. + scaling=False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py. + binarization=False # If true, Otsu thresholding will be applied to augment the input with binarized images. + dir_train=None # Directory of training dataset with subdirectories having the names "images" and "labels". + dir_eval=None # Directory of validation dataset with subdirectories having the names "images" and "labels". + dir_output=None # Directory where the output model will be saved. + pretraining=False # Set to true to load pretrained weights of ResNet50 encoder. + scaling_bluring=False # If true, a combination of scaling and blurring will be applied to the image. + scaling_binarization=False # If true, a combination of scaling and binarization will be applied to the image. 
+ scaling_flip=False # If true, a combination of scaling and flipping will be applied to the image. + thetha=[10,-10] # Rotate image by these angles for augmentation. + blur_k=['blur','gauss','median'] # Blur image for augmentation. + scales=[0.5,2] # Scale patches for augmentation. + flip_index=[0,1,-1] # Flip image for augmentation. + continue_training = False # Set to true if you would like to continue training an already trained a model. + index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. + dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. + is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. + weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. + data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". 
@ex.automain def run(n_classes,n_epochs,input_height, From 666a62622ee95f2c155eb6db6dfa58bd31f15971 Mon Sep 17 00:00:00 2001 From: cneud <952378+cneud@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:20:23 +0200 Subject: [PATCH 038/123] code formatting with black; typos --- train/README.md | 6 +- ..._model_load_pretrained_weights_and_save.py | 14 +- train/config_params.json | 6 +- train/metrics.py | 209 ++--- train/models.py | 237 +++--- train/requirements.txt | 2 + train/train.py | 272 +++---- train/utils.py | 763 +++++++++--------- 8 files changed, 741 insertions(+), 768 deletions(-) diff --git a/train/README.md b/train/README.md index 89fa227..899c9a3 100644 --- a/train/README.md +++ b/train/README.md @@ -48,7 +48,7 @@ If you have an image label with height and width of 10, for a binary case the fi A small sample of training data for binarization experiment can be found here, [Training data sample](https://qurator-data.de/~vahid.rezanezhad/binarization_training_data_sample/), which contains images and lables folders. ### Training , evaluation and output -The train and evaluation folders should contain subfolders of images and labels. +The train and evaluation folders should contain subfolders of images and labels. The output folder should be an empty folder where the output model will be written to. ### Parameter configuration @@ -63,7 +63,7 @@ The output folder should be an empty folder where the output model will be writt * flip_aug: If ``true``, different types of filp will be applied on image. Type of flips is given with "flip_index" in train.py file. * blur_aug: If ``true``, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" in train.py file. * scaling: If ``true``, scaling will be applied on image. Scale of scaling is given with "scales" in train.py file. -* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rothation angles are given with "thetha" in train.py file. 
+* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" in train.py file. * rotation: If ``true``, 90 degree rotation will be applied on image. * binarization: If ``true``,Otsu thresholding will be applied to augment the input data with binarized images. * scaling_bluring: If ``true``, combination of scaling and blurring will be applied on image. @@ -73,5 +73,3 @@ The output folder should be an empty folder where the output model will be writt * weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` * data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". * dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. 
- - diff --git a/train/build_model_load_pretrained_weights_and_save.py b/train/build_model_load_pretrained_weights_and_save.py index 3b1a577..125611e 100644 --- a/train/build_model_load_pretrained_weights_and_save.py +++ b/train/build_model_load_pretrained_weights_and_save.py @@ -9,25 +9,21 @@ from utils import * from metrics import * - - def configuration(): gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) -if __name__=='__main__': +if __name__ == '__main__': n_classes = 2 input_height = 224 input_width = 448 weight_decay = 1e-6 pretraining = False dir_of_weights = 'model_bin_sbb_ens.h5' - - #configuration() - - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + + # configuration() + + model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) model.load_weights(dir_of_weights) model.save('./name_in_another_python_version.h5') - - diff --git a/train/config_params.json b/train/config_params.json index eaa50e1..7505a81 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -24,7 +24,7 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/home/vahid/Documents/handwrittens_train/train", - "dir_eval": "/home/vahid/Documents/handwrittens_train/eval", - "dir_output": "/home/vahid/Documents/handwrittens_train/output" + "dir_train": "/train", + "dir_eval": "/eval", + "dir_output": "/output" } diff --git a/train/metrics.py b/train/metrics.py index 1768960..cd30b02 100644 --- a/train/metrics.py +++ b/train/metrics.py @@ -2,8 +2,8 @@ from tensorflow.keras import backend as K import tensorflow as tf import numpy as np -def focal_loss(gamma=2., alpha=4.): +def focal_loss(gamma=2., alpha=4.): gamma = float(gamma) alpha = float(alpha) @@ -37,8 +37,10 @@ def focal_loss(gamma=2., alpha=4.): fl = tf.multiply(alpha, tf.multiply(weight, ce)) reduced_fl = 
tf.reduce_max(fl, axis=1) return tf.reduce_mean(reduced_fl) + return focal_loss_fixed + def weighted_categorical_crossentropy(weights=None): """ weighted_categorical_crossentropy @@ -50,117 +52,131 @@ def weighted_categorical_crossentropy(weights=None): def loss(y_true, y_pred): labels_floats = tf.cast(y_true, tf.float32) - per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) - + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred) + if weights is not None: weight_mask = tf.maximum(tf.reduce_max(tf.constant( np.array(weights, dtype=np.float32)[None, None, None]) - * labels_floats, axis=-1), 1.0) + * labels_floats, axis=-1), 1.0) per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] return tf.reduce_mean(per_pixel_loss) + return loss + + def image_categorical_cross_entropy(y_true, y_pred, weights=None): """ :param y_true: tensor of shape (batch_size, height, width) representing the ground truth. :param y_pred: tensor of shape (batch_size, height, width) representing the prediction. :return: The mean cross-entropy on softmaxed tensors. 
""" - + labels_floats = tf.cast(y_true, tf.float32) - per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) - + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred) + if weights is not None: weight_mask = tf.maximum( - tf.reduce_max(tf.constant( - np.array(weights, dtype=np.float32)[None, None, None]) - * labels_floats, axis=-1), 1.0) + tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] - - return tf.reduce_mean(per_pixel_loss) -def class_tversky(y_true, y_pred): - smooth = 1.0#1.00 - y_true = K.permute_dimensions(y_true, (3,1,2,0)) - y_pred = K.permute_dimensions(y_pred, (3,1,2,0)) + return tf.reduce_mean(per_pixel_loss) + + +def class_tversky(y_true, y_pred): + smooth = 1.0 # 1.00 + + y_true = K.permute_dimensions(y_true, (3, 1, 2, 0)) + y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0)) y_true_pos = K.batch_flatten(y_true) y_pred_pos = K.batch_flatten(y_pred) true_pos = K.sum(y_true_pos * y_pred_pos, 1) - false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1) - false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1) - alpha = 0.2#0.5 - beta=0.8 - return (true_pos + smooth)/(true_pos + alpha*false_neg + (beta)*false_pos + smooth) + false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1) + false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1) + alpha = 0.2 # 0.5 + beta = 0.8 + return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth) -def focal_tversky_loss(y_true,y_pred): + +def focal_tversky_loss(y_true, y_pred): pt_1 = class_tversky(y_true, y_pred) - gamma =1.3#4./3.0#1.3#4.0/3.00# 0.75 - return K.sum(K.pow((1-pt_1), gamma)) + gamma = 1.3 # 4./3.0#1.3#4.0/3.00# 0.75 + return K.sum(K.pow((1 - pt_1), gamma)) + def generalized_dice_coeff2(y_true, y_pred): n_el = 1 - for dim in y_true.shape: + for dim in y_true.shape: n_el *= int(dim) n_cl = 
y_true.shape[-1] w = K.zeros(shape=(n_cl,)) - w = (K.sum(y_true, axis=(0,1,2)))/(n_el) - w = 1/(w**2+0.000001) - numerator = y_true*y_pred - numerator = w*K.sum(numerator,(0,1,2)) + w = (K.sum(y_true, axis=(0, 1, 2))) / n_el + w = 1 / (w ** 2 + 0.000001) + numerator = y_true * y_pred + numerator = w * K.sum(numerator, (0, 1, 2)) numerator = K.sum(numerator) - denominator = y_true+y_pred - denominator = w*K.sum(denominator,(0,1,2)) + denominator = y_true + y_pred + denominator = w * K.sum(denominator, (0, 1, 2)) denominator = K.sum(denominator) - return 2*numerator/denominator + return 2 * numerator / denominator + + def generalized_dice_coeff(y_true, y_pred): - axes = tuple(range(1, len(y_pred.shape)-1)) + axes = tuple(range(1, len(y_pred.shape) - 1)) Ncl = y_pred.shape[-1] w = K.zeros(shape=(Ncl,)) w = K.sum(y_true, axis=axes) - w = 1/(w**2+0.000001) + w = 1 / (w ** 2 + 0.000001) # Compute gen dice coef: - numerator = y_true*y_pred - numerator = w*K.sum(numerator,axes) + numerator = y_true * y_pred + numerator = w * K.sum(numerator, axes) numerator = K.sum(numerator) - denominator = y_true+y_pred - denominator = w*K.sum(denominator,axes) + denominator = y_true + y_pred + denominator = w * K.sum(denominator, axes) denominator = K.sum(denominator) - gen_dice_coef = 2*numerator/denominator + gen_dice_coef = 2 * numerator / denominator return gen_dice_coef + def generalized_dice_loss(y_true, y_pred): return 1 - generalized_dice_coeff2(y_true, y_pred) -def soft_dice_loss(y_true, y_pred, epsilon=1e-6): - ''' + + +def soft_dice_loss(y_true, y_pred, epsilon=1e-6): + """ Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. Assumes the `channels_last` format. - + # Arguments y_true: b x X x Y( x Z...) x c One hot encoding of ground truth - y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) + y_pred: b x X x Y( x Z...) 
x c Network output, must sum to 1 over c channel (such as after softmax) epsilon: Used for numerical stability to avoid divide by zero errors - + # References - V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation https://arxiv.org/abs/1606.04797 - More details on Dice loss formulation + More details on Dice loss formulation https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72) - + Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022 - ''' - + """ + # skip the batch and class axis for calculating Dice score - axes = tuple(range(1, len(y_pred.shape)-1)) - + axes = tuple(range(1, len(y_pred.shape) - 1)) + numerator = 2. * K.sum(y_pred * y_true, axes) denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) - return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch -def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = True, mean_per_class=False, verbose=False): + +def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, + verbose=False): """ Compute mean metrics of two segmentation masks, via Keras. 
@@ -193,13 +209,13 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = H = height, N = number of classes """ - + flag_soft = (metric_type == 'soft') flag_naive_mean = (metric_type == 'naive') - + # always assume one or more classes num_classes = K.shape(y_true)[-1] - + if not flag_soft: # get one-hot encoded masks from y_pred (true masks should already be one-hot) y_pred = K.one_hot(K.argmax(y_pred), num_classes) @@ -211,29 +227,29 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = y_pred = K.cast(y_pred, 'float32') # intersection and union shapes are batch_size * n_classes (values = area in pixels) - axes = (1,2) # W,H axes of each image + axes = (1, 2) # W,H axes of each image intersection = K.sum(K.abs(y_true * y_pred), axis=axes) mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes) - union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot + union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot smooth = .001 iou = (intersection + smooth) / (union + smooth) - dice = 2 * (intersection + smooth)/(mask_sum + smooth) + dice = 2 * (intersection + smooth) / (mask_sum + smooth) metric = {'iou': iou, 'dice': dice}[metric_name] # define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise - mask = K.cast(K.not_equal(union, 0), 'float32') - + mask = K.cast(K.not_equal(union, 0), 'float32') + if drop_last: - metric = metric[:,:-1] - mask = mask[:,:-1] - + metric = metric[:, :-1] + mask = mask[:, :-1] + if verbose: print('intersection, union') print(K.eval(intersection), K.eval(union)) - print(K.eval(intersection/union)) - + print(K.eval(intersection / union)) + # return mean metrics: remaining axes are (batch, classes) if flag_naive_mean: return K.mean(metric) @@ -243,13 +259,14 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = non_zero = tf.greater(class_count, 0) non_zero_sum 
= tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero) non_zero_count = tf.boolean_mask(class_count, non_zero) - + if verbose: print('Counts of inputs with class present, metrics for non-absent classes') print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count)) - + return K.mean(non_zero_sum / non_zero_count) + def mean_iou(y_true, y_pred, **kwargs): """ Compute mean Intersection over Union of two segmentation masks, via Keras. @@ -257,65 +274,69 @@ def mean_iou(y_true, y_pred, **kwargs): Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs. """ return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs) + + def Mean_IOU(y_true, y_pred): nb_classes = K.int_shape(y_pred)[-1] iou = [] true_pixels = K.argmax(y_true, axis=-1) pred_pixels = K.argmax(y_pred, axis=-1) void_labels = K.equal(K.sum(y_true, axis=-1), 0) - for i in range(0, nb_classes): # exclude first label (background) and last label (void) - true_labels = K.equal(true_pixels, i)# & ~void_labels - pred_labels = K.equal(pred_pixels, i)# & ~void_labels + for i in range(0, nb_classes): # exclude first label (background) and last label (void) + true_labels = K.equal(true_pixels, i) # & ~void_labels + pred_labels = K.equal(pred_pixels, i) # & ~void_labels inter = tf.to_int32(true_labels & pred_labels) union = tf.to_int32(true_labels | pred_labels) - legal_batches = K.sum(tf.to_int32(true_labels), axis=1)>0 - ious = K.sum(inter, axis=1)/K.sum(union, axis=1) - iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects + legal_batches = K.sum(tf.to_int32(true_labels), axis=1) > 0 + ious = K.sum(inter, axis=1) / K.sum(union, axis=1) + iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects iou = tf.stack(iou) legal_labels = ~tf.debugging.is_nan(iou) iou = tf.gather(iou, indices=tf.where(legal_labels)) return K.mean(iou) + def iou_vahid(y_true, y_pred): - nb_classes = 
tf.shape(y_true)[-1]+tf.to_int32(1) + nb_classes = tf.shape(y_true)[-1] + tf.to_int32(1) true_pixels = K.argmax(y_true, axis=-1) pred_pixels = K.argmax(y_pred, axis=-1) iou = [] - + for i in tf.range(nb_classes): - tp=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) - fp=K.sum( tf.to_int32( K.not_equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) - fn=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.not_equal(pred_pixels, i) ) ) - iouh=tp/(tp+fp+fn) + tp = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.equal(pred_pixels, i))) + fp = K.sum(tf.to_int32(K.not_equal(true_pixels, i) & K.equal(pred_pixels, i))) + fn = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.not_equal(pred_pixels, i))) + iouh = tp / (tp + fp + fn) iou.append(iouh) return K.mean(iou) - - -def IoU_metric(Yi,y_predi): - ## mean Intersection over Union - ## Mean IoU = TP/(FN + TP + FP) + + +def IoU_metric(Yi, y_predi): + # mean Intersection over Union + # Mean IoU = TP/(FN + TP + FP) y_predi = np.argmax(y_predi, axis=3) y_testi = np.argmax(Yi, axis=3) IoUs = [] Nclass = int(np.max(Yi)) + 1 for c in range(Nclass): - TP = np.sum( (Yi == c)&(y_predi==c) ) - FP = np.sum( (Yi != c)&(y_predi==c) ) - FN = np.sum( (Yi == c)&(y_predi != c)) - IoU = TP/float(TP + FP + FN) + TP = np.sum((Yi == c) & (y_predi == c)) + FP = np.sum((Yi != c) & (y_predi == c)) + FN = np.sum((Yi == c) & (y_predi != c)) + IoU = TP / float(TP + FP + FN) IoUs.append(IoU) - return K.cast( np.mean(IoUs) ,dtype='float32' ) + return K.cast(np.mean(IoUs), dtype='float32') def IoU_metric_keras(y_true, y_pred): - ## mean Intersection over Union - ## Mean IoU = TP/(FN + TP + FP) + # mean Intersection over Union + # Mean IoU = TP/(FN + TP + FP) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) - + return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) + def jaccard_distance_loss(y_true, y_pred, smooth=100): """ Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) @@ -334,5 +355,3 @@ def 
jaccard_distance_loss(y_true, y_pred, smooth=100): sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) jac = (intersection + smooth) / (sum_ - intersection + smooth) return (1 - jac) * smooth - - diff --git a/train/models.py b/train/models.py index 40a21a1..f06823e 100644 --- a/train/models.py +++ b/train/models.py @@ -3,19 +3,20 @@ from tensorflow.keras.layers import * from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 -resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' -IMAGE_ORDERING ='channels_last' -MERGE_AXIS=-1 +resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +IMAGE_ORDERING = 'channels_last' +MERGE_AXIS = -1 -def one_side_pad( x ): +def one_side_pad(x): x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) if IMAGE_ORDERING == 'channels_first': - x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x) + x = Lambda(lambda x: x[:, :, :-1, :-1])(x) elif IMAGE_ORDERING == 'channels_last': - x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x) + x = Lambda(lambda x: x[:, :-1, :-1, :])(x) return x + def identity_block(input_tensor, kernel_size, filters, stage, block): """The identity block is the block that has no conv layer at shortcut. # Arguments @@ -28,7 +29,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): Output tensor for the block. 
""" filters1, filters2, filters3 = filters - + if IMAGE_ORDERING == 'channels_last': bn_axis = 3 else: @@ -37,16 +38,16 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): conv_name_base = 'res' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch' - x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor) + x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = Activation('relu')(x) - x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , + x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same', name=conv_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) - x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) x = layers.add([x, input_tensor]) @@ -68,7 +69,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) And the shortcut should have strides=(2,2) as well """ filters1, filters2, filters3 = filters - + if IMAGE_ORDERING == 'channels_last': bn_axis = 3 else: @@ -77,20 +78,20 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) conv_name_base = 'res' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch' - x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides, name=conv_name_base + '2a')(input_tensor) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = Activation('relu')(x) - x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same', + x = 
Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same', name=conv_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) - x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) - shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides, name=conv_name_base + '1')(input_tensor) shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) @@ -99,12 +100,11 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) return x -def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): - assert input_height%32 == 0 - assert input_width%32 == 0 +def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): + assert input_height % 32 == 0 + assert input_width % 32 == 0 - - img_input = Input(shape=(input_height,input_width , 3 )) + img_input = Input(shape=(input_height, input_width, 3)) if IMAGE_ORDERING == 'channels_last': bn_axis = 3 @@ -112,25 +112,24 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay= bn_axis = 1 x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), + name='conv1')(x) f1 = x x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = Activation('relu')(x) - x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) - + x = MaxPooling2D((3, 3), 
data_format=IMAGE_ORDERING, strides=(2, 2))(x) x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x ) - + f2 = one_side_pad(x) x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x + f3 = x x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') @@ -138,85 +137,72 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay= x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x + f4 = x x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x - + f5 = x if pretraining: - model=Model( img_input , x ).load_weights(resnet50_Weights_path) + model = Model(img_input, x).load_weights(resnet50_Weights_path) - - v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) - v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048) + v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) + v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048) v512_2048 = Activation('relu')(v512_2048) - - - v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 ) - v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024) + v512_1024 = Conv2D(512, (1, 1), padding='same', 
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4) + v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024) v512_1024 = Activation('relu')(v512_1024) - - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048) - o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) - o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048) + o = (concatenate([o, v512_1024], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) - o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) - o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], 
axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) - o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) - o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, img_input], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - - o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) - o = ( BatchNormalization(axis=bn_axis))(o) + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = 
(Activation('softmax'))(o) - - model = Model( img_input , o ) + model = Model(img_input, o) return model -def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): - assert input_height%32 == 0 - assert input_width%32 == 0 - - img_input = Input(shape=(input_height,input_width , 3 )) +def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): + assert input_height % 32 == 0 + assert input_width % 32 == 0 + + img_input = Input(shape=(input_height, input_width, 3)) if IMAGE_ORDERING == 'channels_last': bn_axis = 3 @@ -224,25 +210,24 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p bn_axis = 1 x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), + name='conv1')(x) f1 = x x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = Activation('relu')(x) - x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) - + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x ) - + f2 = one_side_pad(x) x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x + f3 = x x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') @@ -250,68 +235,60 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p 
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x + f4 = x x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x + f5 = x if pretraining: - Model( img_input , x ).load_weights(resnet50_Weights_path) + Model(img_input, x).load_weights(resnet50_Weights_path) - v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) - v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( + f5) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) v1024_2048 = Activation('relu')(v1024_2048) - - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048) - o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) - o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) - o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = ( 
BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) - o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) - o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) - o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) - o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) - o = ( Conv2D( 32 , (3, 3), padding='valid' , 
data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) - o = ( BatchNormalization(axis=bn_axis))(o) + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, img_input], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = Activation('relu')(o) - - - o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) - o = ( BatchNormalization(axis=bn_axis))(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = (BatchNormalization(axis=bn_axis))(o) o = (Activation('softmax'))(o) - - model = Model( img_input , o ) - - + model = Model(img_input, o) return model diff --git a/train/requirements.txt b/train/requirements.txt index cbe2d88..20b6a32 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -4,3 +4,5 @@ opencv-python-headless seaborn tqdm imutils +numpy +scipy diff --git a/train/train.py b/train/train.py index 9f833e0..03faf46 100644 --- a/train/train.py +++ b/train/train.py @@ -11,12 +11,14 @@ from metrics import * from tensorflow.keras.models import load_model from tqdm import tqdm + def configuration(): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True session = tf.compat.v1.Session(config=config) set_session(session) + def get_dirs_or_files(input_data): if os.path.isdir(input_data): image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/') @@ -25,205 +27,187 @@ def get_dirs_or_files(input_data): assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) return image_input, labels_input + ex = Experiment() + @ex.config def config_params(): - n_classes=None # Number of classes. 
In the case of binary classification this should be 2. - n_epochs=1 # Number of epochs. - input_height=224*1 # Height of model's input in pixels. - input_width=224*1 # Width of model's input in pixels. - weight_decay=1e-6 # Weight decay of l2 regularization of model layers. - n_batch=1 # Number of batches at each iteration. - learning_rate=1e-4 # Set the learning rate. - patches=False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. - augmentation=False # To apply any kind of augmentation, this parameter must be set to true. - flip_aug=False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py. - blur_aug=False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py. - scaling=False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py. - binarization=False # If true, Otsu thresholding will be applied to augment the input with binarized images. - dir_train=None # Directory of training dataset with subdirectories having the names "images" and "labels". - dir_eval=None # Directory of validation dataset with subdirectories having the names "images" and "labels". - dir_output=None # Directory where the output model will be saved. - pretraining=False # Set to true to load pretrained weights of ResNet50 encoder. - scaling_bluring=False # If true, a combination of scaling and blurring will be applied to the image. - scaling_binarization=False # If true, a combination of scaling and binarization will be applied to the image. - scaling_flip=False # If true, a combination of scaling and flipping will be applied to the image. - thetha=[10,-10] # Rotate image by these angles for augmentation. - blur_k=['blur','gauss','median'] # Blur image for augmentation. 
- scales=[0.5,2] # Scale patches for augmentation. - flip_index=[0,1,-1] # Flip image for augmentation. - continue_training = False # Set to true if you would like to continue training an already trained a model. - index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. - dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. - is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. - weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. - data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". + n_classes = None # Number of classes. In the case of binary classification this should be 2. + n_epochs = 1 # Number of epochs. + input_height = 224 * 1 # Height of model's input in pixels. + input_width = 224 * 1 # Width of model's input in pixels. + weight_decay = 1e-6 # Weight decay of l2 regularization of model layers. + n_batch = 1 # Number of batches at each iteration. + learning_rate = 1e-4 # Set the learning rate. + patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. + augmentation = False # To apply any kind of augmentation, this parameter must be set to true. + flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py. + blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py. + scaling = False # If true, scaling will be applied to the image. 
The amount of scaling is defined with "scales" in train.py. + binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. + dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". + dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". + dir_output = None # Directory where the output model will be saved. + pretraining = False # Set to true to load pretrained weights of ResNet50 encoder. + scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. + scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. + scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. + thetha = [10, -10] # Rotate image by these angles for augmentation. + blur_k = ['blur', 'gauss', 'median'] # Blur image for augmentation. + scales = [0.5, 2] # Scale patches for augmentation. + flip_index = [0, 1, -1] # Flip image for augmentation. + continue_training = False # Set to true if you would like to continue training an already trained a model. + index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. + dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. + is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. + weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. + data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". 
+ @ex.automain -def run(n_classes,n_epochs,input_height, - input_width,weight_decay,weighted_loss, - index_start,dir_of_start_model,is_loss_soft_dice, - n_batch,patches,augmentation,flip_aug - ,blur_aug,scaling, binarization, - blur_k,scales,dir_train,data_is_provided, - scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip,continue_training, - flip_index,dir_eval ,dir_output,pretraining,learning_rate): - - +def run(n_classes, n_epochs, input_height, + input_width, weight_decay, weighted_loss, + index_start, dir_of_start_model, is_loss_soft_dice, + n_batch, patches, augmentation, flip_aug, + blur_aug, scaling, binarization, + blur_k, scales, dir_train, data_is_provided, + scaling_bluring, scaling_binarization, rotation, + rotation_not_90, thetha, scaling_flip, continue_training, + flip_index, dir_eval, dir_output, pretraining, learning_rate): if data_is_provided: - dir_train_flowing=os.path.join(dir_output,'train') - dir_eval_flowing=os.path.join(dir_output,'eval') - - dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') - dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') - - dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') - dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') - + dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') + + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') + + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels') + configuration() - + else: - dir_img,dir_seg=get_dirs_or_files(dir_train) - dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) - + dir_img, dir_seg = get_dirs_or_files(dir_train) + dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) + # make first a directory in output for both training and evaluations in order to flow data from these 
directories. - dir_train_flowing=os.path.join(dir_output,'train') - dir_eval_flowing=os.path.join(dir_output,'eval') - - dir_flow_train_imgs=os.path.join(dir_train_flowing,'images/') - dir_flow_train_labels=os.path.join(dir_train_flowing,'labels/') - - dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images/') - dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels/') - + dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') + + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/') + + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/') + if os.path.isdir(dir_train_flowing): - os.system('rm -rf '+dir_train_flowing) + os.system('rm -rf ' + dir_train_flowing) os.makedirs(dir_train_flowing) else: os.makedirs(dir_train_flowing) - + if os.path.isdir(dir_eval_flowing): - os.system('rm -rf '+dir_eval_flowing) + os.system('rm -rf ' + dir_eval_flowing) os.makedirs(dir_eval_flowing) else: os.makedirs(dir_eval_flowing) - os.mkdir(dir_flow_train_imgs) os.mkdir(dir_flow_train_labels) - + os.mkdir(dir_flow_eval_imgs) os.mkdir(dir_flow_eval_labels) - - - #set the gpu configuration + + # set the gpu configuration configuration() - - #writing patches into a sub-folder in order to be flowed from directory. - provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + # writing patches into a sub-folder in order to be flowed from directory. 
+ provide_patches(dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip, - augmentation=augmentation,patches=patches) - - provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, - dir_flow_eval_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip, - augmentation=False,patches=patches) - + input_height, input_width, blur_k, blur_aug, + flip_aug, binarization, scaling, scales, flip_index, + scaling_bluring, scaling_binarization, rotation, + rotation_not_90, thetha, scaling_flip, + augmentation=augmentation, patches=patches) + + provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs, + dir_flow_eval_labels, + input_height, input_width, blur_k, blur_aug, + flip_aug, binarization, scaling, scales, flip_index, + scaling_bluring, scaling_binarization, rotation, + rotation_not_90, thetha, scaling_flip, + augmentation=False, patches=patches) - if weighted_loss: - weights=np.zeros(n_classes) + weights = np.zeros(n_classes) if data_is_provided: for obj in os.listdir(dir_flow_train_labels): try: - label_obj=cv2.imread(dir_flow_train_labels+'/'+obj) - label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) - weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + label_obj = cv2.imread(dir_flow_train_labels + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) except: pass else: - + for obj in os.listdir(dir_seg): try: - label_obj=cv2.imread(dir_seg+'/'+obj) - label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) - weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + 
label_obj = cv2.imread(dir_seg + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) except: pass - - weights=1.00/weights - - weights=weights/float(np.sum(weights)) - weights=weights/float(np.min(weights)) - weights=weights/float(np.sum(weights)) - - - + weights = 1.00 / weights + + weights = weights / float(np.sum(weights)) + weights = weights / float(np.min(weights)) + weights = weights / float(np.sum(weights)) + if continue_training: if is_loss_soft_dice: - model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) + model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) if weighted_loss: - model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + model = load_model(dir_of_start_model, compile=True, + custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: - model = load_model (dir_of_start_model, compile = True) + model = load_model(dir_of_start_model, compile=True) else: - #get our model. + # get our model. index_start = 0 - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) - - #if you want to see the model structure just uncomment model summary. - #model.summary() - + model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) + + # if you want to see the model structure just uncomment model summary. 
+ # model.summary() if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - if is_loss_soft_dice: + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if is_loss_soft_dice: model.compile(loss=soft_dice_loss, - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if weighted_loss: model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - - #generating train and evaluation data - train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, - input_height=input_height, input_width=input_width,n_classes=n_classes ) - val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, - input_height=input_height, input_width=input_width,n_classes=n_classes ) - - for i in tqdm(range(index_start, n_epochs+index_start)): + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + + # generating train and evaluation data + train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes) + val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes) + + for i in tqdm(range(index_start, n_epochs + index_start)): model.fit_generator( train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, validation_data=val_gen, validation_steps=1, epochs=1) - model.save(dir_output+'/'+'model_'+str(i)) - - - #os.system('rm -rf '+dir_train_flowing) - #os.system('rm -rf '+dir_eval_flowing) - - #model.save(dir_output+'/'+'model'+'.h5') - - - - - - - - + model.save(dir_output + '/' + 'model_' + str(i)) + # os.system('rm -rf 
'+dir_train_flowing) + # os.system('rm -rf '+dir_eval_flowing) + # model.save(dir_output+'/'+'model'+'.h5') diff --git a/train/utils.py b/train/utils.py index 19ab46e..7c65f18 100644 --- a/train/utils.py +++ b/train/utils.py @@ -10,18 +10,17 @@ import imutils import math - -def bluring(img_in,kind): - if kind=='guass': - img_blur = cv2.GaussianBlur(img_in,(5,5),0) - elif kind=="median": - img_blur = cv2.medianBlur(img_in,5) - elif kind=='blur': - img_blur=cv2.blur(img_in,(5,5)) +def bluring(img_in, kind): + if kind == 'gauss': + img_blur = cv2.GaussianBlur(img_in, (5, 5), 0) + elif kind == "median": + img_blur = cv2.medianBlur(img_in, 5) + elif kind == 'blur': + img_blur = cv2.blur(img_in, (5, 5)) return img_blur -def elastic_transform(image, alpha, sigma,seedj, random_state=None): - + +def elastic_transform(image, alpha, sigma, seedj, random_state=None): """Elastic deformation of images as described in [Simard2003]_. .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for Convolutional Neural Networks applied to Visual Document Analysis", in @@ -37,461 +36,459 @@ def elastic_transform(image, alpha, sigma,seedj, random_state=None): dz = np.zeros_like(dx) x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) - indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)), np.reshape(z, (-1, 1)) + indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)) distored_image = map_coordinates(image, indices, order=1, mode='reflect') return distored_image.reshape(image.shape) + def rotation_90(img): - img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) - img_rot[:,:,0]=img[:,:,0].T - img_rot[:,:,1]=img[:,:,1].T - img_rot[:,:,2]=img[:,:,2].T + img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2])) + img_rot[:, :, 0] = img[:, :, 0].T + img_rot[:, :, 1] = img[:, :, 1].T + img_rot[:, :, 2] = img[:, :, 2].T return img_rot + def rotatedRectWithMaxArea(w, h, angle): - """ + """ Given a 
rectangle of size wxh that has been rotated by 'angle' (in radians), computes the width and height of the largest possible axis-aligned rectangle (maximal area) within the rotated rectangle. """ - if w <= 0 or h <= 0: - return 0,0 + if w <= 0 or h <= 0: + return 0, 0 - width_is_longer = w >= h - side_long, side_short = (w,h) if width_is_longer else (h,w) + width_is_longer = w >= h + side_long, side_short = (w, h) if width_is_longer else (h, w) - # since the solutions for angle, -angle and 180-angle are all the same, - # if suffices to look at the first quadrant and the absolute values of sin,cos: - sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) - if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: - # half constrained case: two crop corners touch the longer side, - # the other two corners are on the mid-line parallel to the longer line - x = 0.5*side_short - wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) - else: - # fully constrained case: crop touches all 4 sides - cos_2a = cos_a*cos_a - sin_a*sin_a - wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a + # since the solutions for angle, -angle and 180-angle are all the same, + # if suffices to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2. 
* sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5 * side_short + wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a * cos_a - sin_a * sin_a + wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a - return wr,hr + return wr, hr -def rotate_max_area(image,rotated, rotated_label,angle): + +def rotate_max_area(image, rotated, rotated_label, angle): """ image: cv2 image matrix object angle: in degree """ wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], math.radians(angle)) h, w, _ = rotated.shape - y1 = h//2 - int(hr/2) + y1 = h // 2 - int(hr / 2) y2 = y1 + int(hr) - x1 = w//2 - int(wr/2) + x1 = w // 2 - int(wr / 2) x2 = x1 + int(wr) - return rotated[y1:y2, x1:x2],rotated_label[y1:y2, x1:x2] -def rotation_not_90_func(img,label,thetha): - rotated=imutils.rotate(img,thetha) - rotated_label=imutils.rotate(label,thetha) - return rotate_max_area(img, rotated,rotated_label,thetha) + return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2] + + +def rotation_not_90_func(img, label, thetha): + rotated = imutils.rotate(img, thetha) + rotated_label = imutils.rotate(label, thetha) + return rotate_max_area(img, rotated, rotated_label, thetha) + def color_images(seg, n_classes): - ann_u=range(n_classes) - if len(np.shape(seg))==3: - seg=seg[:,:,0] - - seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(float) - colors=sns.color_palette("hls", n_classes) - + ann_u = range(n_classes) + if len(np.shape(seg)) == 3: + seg = seg[:, :, 0] + + seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(float) + colors = sns.color_palette("hls", n_classes) + for c in ann_u: - c=int(c) - segl=(seg==c) - seg_img[:,:,0]+=segl*(colors[c][0]) - seg_img[:,:,1]+=segl*(colors[c][1]) - 
seg_img[:,:,2]+=segl*(colors[c][2]) + c = int(c) + segl = (seg == c) + seg_img[:, :, 0] += segl * (colors[c][0]) + seg_img[:, :, 1] += segl * (colors[c][1]) + seg_img[:, :, 2] += segl * (colors[c][2]) return seg_img - -def resize_image(seg_in,input_height,input_width): - return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST) -def get_one_hot(seg,input_height,input_width,n_classes): - seg=seg[:,:,0] - seg_f=np.zeros((input_height, input_width,n_classes)) + +def resize_image(seg_in, input_height, input_width): + return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) + + +def get_one_hot(seg, input_height, input_width, n_classes): + seg = seg[:, :, 0] + seg_f = np.zeros((input_height, input_width, n_classes)) for j in range(n_classes): - seg_f[:,:,j]=(seg==j).astype(int) + seg_f[:, :, j] = (seg == j).astype(int) return seg_f - -def IoU(Yi,y_predi): + +def IoU(Yi, y_predi): ## mean Intersection over Union ## Mean IoU = TP/(FN + TP + FP) IoUs = [] - classes_true=np.unique(Yi) + classes_true = np.unique(Yi) for c in classes_true: - TP = np.sum( (Yi == c)&(y_predi==c) ) - FP = np.sum( (Yi != c)&(y_predi==c) ) - FN = np.sum( (Yi == c)&(y_predi != c)) - IoU = TP/float(TP + FP + FN) - print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU)) + TP = np.sum((Yi == c) & (y_predi == c)) + FP = np.sum((Yi != c) & (y_predi == c)) + FN = np.sum((Yi == c) & (y_predi != c)) + IoU = TP / float(TP + FP + FN) + print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU)) IoUs.append(IoU) mIoU = np.mean(IoUs) print("_________________") print("Mean IoU: {:4.3f}".format(mIoU)) return mIoU -def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes): + + +def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes): c = 0 - n = [f for f in os.listdir(img_folder) if not f.startswith('.')]# 
os.listdir(img_folder) #List of training images + n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images random.shuffle(n) while True: img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') - - for i in range(c, c+batch_size): #initially from 0 to 16, c = 0. - #print(img_folder+'/'+n[i]) - - try: - filename=n[i].split('.')[0] - - train_img = cv2.imread(img_folder+'/'+n[i])/255. - train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize - - img[i-c] = train_img #add to array - img[0], img[1], and so on. - train_mask = cv2.imread(mask_folder+'/'+filename+'.png') - #print(mask_folder+'/'+filename+'.png') - #print(train_mask.shape) - train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) - #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] - - mask[i-c] = train_mask - except: - img[i-c] = np.ones((input_height, input_width, 3)).astype('float') - mask[i-c] = np.zeros((input_height, input_width, n_classes)).astype('float') - - - c+=batch_size - if(c+batch_size>=len(os.listdir(img_folder))): - c=0 + for i in range(c, c + batch_size): # initially from 0 to 16, c = 0. + # print(img_folder+'/'+n[i]) + + try: + filename = n[i].split('.')[0] + + train_img = cv2.imread(img_folder + '/' + n[i]) / 255. + train_img = cv2.resize(train_img, (input_width, input_height), + interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize + + img[i - c] = train_img # add to array - img[0], img[1], and so on. 
+ train_mask = cv2.imread(mask_folder + '/' + filename + '.png') + # print(mask_folder+'/'+filename+'.png') + # print(train_mask.shape) + train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, + n_classes) + # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] + + mask[i - c] = train_mask + except: + img[i - c] = np.ones((input_height, input_width, 3)).astype('float') + mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float') + + c += batch_size + if c + batch_size >= len(os.listdir(img_folder)): + c = 0 random.shuffle(n) yield img, mask - + + def otsu_copy(img): - img_r=np.zeros(img.shape) - img1=img[:,:,0] - img2=img[:,:,1] - img3=img[:,:,2] - _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) - _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) - _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) - img_r[:,:,0]=threshold1 - img_r[:,:,1]=threshold1 - img_r[:,:,2]=threshold1 + img_r = np.zeros(img.shape) + img1 = img[:, :, 0] + img2 = img[:, :, 1] + img3 = img[:, :, 2] + _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + img_r[:, :, 0] = threshold1 + img_r[:, :, 1] = threshold1 + img_r[:, :, 2] = threshold1 return img_r -def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer): - if img.shape[0]int(nxf): - nxf=int(nxf)+1 - if nyf>int(nyf): - nyf=int(nyf)+1 - - nxf=int(nxf) - nyf=int(nyf) - + +def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer): + if img.shape[0] < height or img.shape[1] < width: + img, label = do_padding(img, label, height, width) + + img_h = img.shape[0] + img_w = img.shape[1] + + nxf = img_w / 
float(width) + nyf = img_h / float(height) + + if nxf > int(nxf): + nxf = int(nxf) + 1 + if nyf > int(nyf): + nyf = int(nyf) + 1 + + nxf = int(nxf) + nyf = int(nyf) + for i in range(nxf): for j in range(nyf): - index_x_d=i*width - index_x_u=(i+1)*width - - index_y_d=j*height - index_y_u=(j+1)*height - - if index_x_u>img_w: - index_x_u=img_w - index_x_d=img_w-width - if index_y_u>img_h: - index_y_u=img_h - index_y_d=img_h-height - - - img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] - label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] - - cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) - cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) - indexer+=1 - - return indexer + index_x_d = i * width + index_x_u = (i + 1) * width -def do_padding(img,label,height,width): - - height_new=img.shape[0] - width_new=img.shape[1] - - h_start=0 - w_start=0 - - if img.shape[0]int(nxf): - nxf=int(nxf)+1 - if nyf>int(nyf): - nyf=int(nyf)+1 - - nxf=int(nxf) - nyf=int(nyf) - - for i in range(nxf): - for j in range(nyf): - index_x_d=i*width_scale - index_x_u=(i+1)*width_scale - - index_y_d=j*height_scale - index_y_u=(j+1)*height_scale - - if index_x_u>img_w: - index_x_u=img_w - index_x_d=img_w-width_scale - if index_y_u>img_h: - index_y_u=img_h - index_y_d=img_h-height_scale - - - img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] - label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] - - img_patch=resize_image(img_patch,height,width) - label_patch=resize_image(label_patch,height,width) - - cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) - cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) - indexer+=1 + index_y_d = j * height + index_y_u = (j + 1) * height - return indexer + if index_x_u > img_w: + index_x_u = img_w + index_x_d = img_w - width + if index_y_u > img_h: + index_y_u = img_h + index_y_d = img_h - height -def 
get_patches_num_scale_new(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler): - img=resize_image(img,int(img.shape[0]*scaler),int(img.shape[1]*scaler)) - label=resize_image(label,int(label.shape[0]*scaler),int(label.shape[1]*scaler)) - - if img.shape[0]int(nxf): - nxf=int(nxf)+1 - if nyf>int(nyf): - nyf=int(nyf)+1 - - nxf=int(nxf) - nyf=int(nyf) - - for i in range(nxf): - for j in range(nyf): - index_x_d=i*width_scale - index_x_u=(i+1)*width_scale - - index_y_d=j*height_scale - index_y_u=(j+1)*height_scale - - if index_x_u>img_w: - index_x_u=img_w - index_x_d=img_w-width_scale - if index_y_u>img_h: - index_y_u=img_h - index_y_d=img_h-height_scale - - - img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] - label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] - - #img_patch=resize_image(img_patch,height,width) - #label_patch=resize_image(label_patch,height,width) - - cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) - cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) - indexer+=1 + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] + label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] + + cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) + cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) + indexer += 1 return indexer -def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, +def do_padding(img, label, height, width): + height_new = img.shape[0] + width_new = img.shape[1] + + h_start = 0 + w_start = 0 + + if img.shape[0] < height: + h_start = int(abs(height - img.shape[0]) / 2.) + height_new = height + + if img.shape[1] < width: + w_start = int(abs(width - img.shape[1]) / 2.) 
+ width_new = width + + img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255 + label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float) + + img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :]) + label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :]) + + return img_new, label_new + + +def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler): + if img.shape[0] < height or img.shape[1] < width: + img, label = do_padding(img, label, height, width) + + img_h = img.shape[0] + img_w = img.shape[1] + + height_scale = int(height * scaler) + width_scale = int(width * scaler) + + nxf = img_w / float(width_scale) + nyf = img_h / float(height_scale) + + if nxf > int(nxf): + nxf = int(nxf) + 1 + if nyf > int(nyf): + nyf = int(nyf) + 1 + + nxf = int(nxf) + nyf = int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d = i * width_scale + index_x_u = (i + 1) * width_scale + + index_y_d = j * height_scale + index_y_u = (j + 1) * height_scale + + if index_x_u > img_w: + index_x_u = img_w + index_x_d = img_w - width_scale + if index_y_u > img_h: + index_y_u = img_h + index_y_d = img_h - height_scale + + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] + label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] + + img_patch = resize_image(img_patch, height, width) + label_patch = resize_image(label_patch, height, width) + + cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) + cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) + indexer += 1 + + return indexer + + +def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler): + img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler)) + label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler)) + + if 
img.shape[0] < height or img.shape[1] < width: + img, label = do_padding(img, label, height, width) + + img_h = img.shape[0] + img_w = img.shape[1] + + height_scale = int(height * 1) + width_scale = int(width * 1) + + nxf = img_w / float(width_scale) + nyf = img_h / float(height_scale) + + if nxf > int(nxf): + nxf = int(nxf) + 1 + if nyf > int(nyf): + nyf = int(nyf) + 1 + + nxf = int(nxf) + nyf = int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d = i * width_scale + index_x_u = (i + 1) * width_scale + + index_y_d = j * height_scale + index_y_u = (j + 1) * height_scale + + if index_x_u > img_w: + index_x_u = img_w + index_x_d = img_w - width_scale + if index_y_u > img_h: + index_y_u = img_h + index_y_d = img_h - height_scale + + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] + label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] + + # img_patch=resize_image(img_patch,height,width) + # label_patch=resize_image(label_patch,height,width) + + cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) + cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) + indexer += 1 + + return indexer + + +def provide_patches(dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip, - augmentation=False,patches=False): - - imgs_cv_train=np.array(os.listdir(dir_img)) - segs_cv_train=np.array(os.listdir(dir_seg)) - - indexer=0 - for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): - img_name=im.split('.')[0] + input_height, input_width, blur_k, blur_aug, + flip_aug, binarization, scaling, scales, flip_index, + scaling_bluring, scaling_binarization, rotation, + rotation_not_90, thetha, scaling_flip, + augmentation=False, patches=False): + imgs_cv_train = np.array(os.listdir(dir_img)) + segs_cv_train = np.array(os.listdir(dir_seg)) 
+ + indexer = 0 + for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)): + img_name = im.split('.')[0] if not patches: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) - indexer+=1 - + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + indexer += 1 + if augmentation: if flip_aug: for f_i in flip_index: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', - resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) - - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , - resize_image(cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png'),f_i),input_height,input_width) ) - indexer+=1 - - if blur_aug: - for blur_i in blur_k: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', - (resize_image(bluring(cv2.imread(dir_img+'/'+im),blur_i),input_height,input_width) ) ) - - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , - resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width) ) - indexer+=1 - - - if binarization: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', - resize_image(otsu_copy( cv2.imread(dir_img+'/'+im)),input_height,input_width )) - - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', - resize_image( cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width )) - indexer+=1 - - - + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, + input_width)) + + 
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), + input_height, input_width)) + indexer += 1 + + if blur_aug: + for blur_i in blur_k: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, + input_width))) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, + input_width)) + indexer += 1 + + if binarization: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + indexer += 1 - - if patches: - - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) - + + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer) + if augmentation: - + if rotation: - - - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - rotation_90( cv2.imread(dir_img+'/'+im) ), - rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ), - input_height,input_width,indexer=indexer) - + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + rotation_90(cv2.imread(dir_img + '/' + im)), + rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), + input_height, input_width, indexer=indexer) + if rotation_not_90: - + for thetha_i in thetha: - img_max_rotated,label_max_rotated=rotation_not_90_func(cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),thetha_i) - 
indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - img_max_rotated, - label_max_rotated, - input_height,input_width,indexer=indexer) + img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im), + cv2.imread( + dir_seg + '/' + img_name + '.png'), + thetha_i) + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + img_max_rotated, + label_max_rotated, + input_height, input_width, indexer=indexer) if flip_aug: for f_i in flip_index: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - cv2.flip( cv2.imread(dir_img+'/'+im) , f_i), - cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i), - input_height,input_width,indexer=indexer) - if blur_aug: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + cv2.flip(cv2.imread(dir_img + '/' + im), f_i), + cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), + input_height, input_width, indexer=indexer) + if blur_aug: for blur_i in blur_k: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + bluring(cv2.imread(dir_img + '/' + im), blur_i), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer) - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - bluring( cv2.imread(dir_img+'/'+im) , blur_i), - cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) - - - if scaling: + if scaling: for sc_ind in scales: - indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im) , - cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer,scaler=sc_ind) + indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, + cv2.imread(dir_img + '/' + im), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer, scaler=sc_ind) if binarization: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - otsu_copy( 
cv2.imread(dir_img+'/'+im)), - cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) - - - - if scaling_bluring: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + otsu_copy(cv2.imread(dir_img + '/' + im)), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer) + + if scaling_bluring: for sc_ind in scales: for blur_i in blur_k: - indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, - bluring( cv2.imread(dir_img+'/'+im) , blur_i) , - cv2.imread(dir_seg+'/'+img_name+'.png') , - input_height,input_width,indexer=indexer,scaler=sc_ind) + indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, + bluring(cv2.imread(dir_img + '/' + im), blur_i), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer, + scaler=sc_ind) - if scaling_binarization: + if scaling_binarization: for sc_ind in scales: - indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, - otsu_copy( cv2.imread(dir_img+'/'+im)) , - cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer,scaler=sc_ind) - - if scaling_flip: + indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, + otsu_copy(cv2.imread(dir_img + '/' + im)), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer, scaler=sc_ind) + + if scaling_flip: for sc_ind in scales: for f_i in flip_index: - indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, - cv2.flip( cv2.imread(dir_img+'/'+im) , f_i) , - cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i) , - input_height,input_width,indexer=indexer,scaler=sc_ind) - - - - - - - + indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, + cv2.flip(cv2.imread(dir_img + '/' + im), f_i), + cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), + f_i), + input_height, 
input_width, indexer=indexer, + scaler=sc_ind) From 6e06742e66be00aba83919a3d49774ed1f54c790 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 16 Apr 2024 01:00:48 +0200 Subject: [PATCH 039/123] first working update of branch --- train/config_params.json | 19 ++- train/models.py | 179 +++++++++++++++++++++++++ train/train.py | 132 ++++++++++++------- train/utils.py | 273 +++++++++++++++++++++++++-------------- 4 files changed, 452 insertions(+), 151 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index 7505a81..bd47a52 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,8 +1,9 @@ { - "n_classes" : 3, + "model_name" : "hybrid_transformer_cnn", + "n_classes" : 2, "n_epochs" : 2, "input_height" : 448, - "input_width" : 672, + "input_width" : 448, "weight_decay" : 1e-6, "n_batch" : 2, "learning_rate": 1e-4, @@ -18,13 +19,21 @@ "scaling_flip" : false, "rotation": false, "rotation_not_90": false, + "num_patches_xy": [28, 28], + "transformer_patchsize": 1, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], "continue_training": false, - "index_start": 0, - "dir_of_start_model": " ", + "index_start" : 0, + "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, "dir_train": "/train", "dir_eval": "/eval", - "dir_output": "/output" + "dir_output": "/out" } diff --git a/train/models.py b/train/models.py index f06823e..f7a7ad8 100644 --- a/train/models.py +++ b/train/models.py @@ -1,13 +1,81 @@ +import tensorflow as tf +from tensorflow import keras from tensorflow.keras.models import * from tensorflow.keras.layers import * from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 +mlp_head_units = [2048, 1024] +projection_dim = 64 +transformer_layers = 8 +num_heads = 4 
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 +transformer_units = [ + projection_dim * 2, + projection_dim, +] # Size of the transformer layers +def mlp(x, hidden_units, dropout_rate): + for units in hidden_units: + x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dropout(dropout_rate)(x) + return x + +class Patches(layers.Layer): + def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + print(tf.shape(images)[1],'images') + print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + print(patches.shape,patch_dims,'patch_dims') + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config + +class PatchEncoder(layers.Layer): + def __init__(self, num_patches, projection_dim): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + + def 
one_side_pad(x): x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) if IMAGE_ORDERING == 'channels_first': @@ -292,3 +360,114 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e- model = Model(img_input, o) return model + + +def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + inputs = layers.Input(shape=(input_height, input_width, 3)) + IMAGE_ORDERING = 'channels_last' + bn_axis=3 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + model = keras.Model(inputs, 
x).load_weights(resnet50_Weights_path) + + num_patches = x.shape[1]*x.shape[2] + patches = Patches(patch_size)(x) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) + # Skip connection 2. + encoded_patches = layers.Add()([x3, x2]) + + encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64]) + + v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o ,f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), 
data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, inputs],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + model = keras.Model(inputs=inputs, outputs=o) + + return model diff --git a/train/train.py b/train/train.py index 03faf46..6e6a172 100644 --- a/train/train.py +++ b/train/train.py @@ -10,6 +10,7 @@ from utils import * from metrics import * from tensorflow.keras.models import load_model from tqdm import tqdm +import json def configuration(): @@ -42,9 +43,13 @@ def config_params(): learning_rate = 1e-4 # Set the learning rate. patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. augmentation = False # To apply any kind of augmentation, this parameter must be set to true. - flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py. 
- blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py. - scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py. + flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json. + blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json. + padding_white = False # If true, white padding will be applied to the image. + padding_black = False # If true, black padding will be applied to the image. + scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json. + degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. + brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". @@ -52,13 +57,18 @@ def config_params(): pretraining = False # Set to true to load pretrained weights of ResNet50 encoder. scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. + scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. 
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. - thetha = [10, -10] # Rotate image by these angles for augmentation. - blur_k = ['blur', 'gauss', 'median'] # Blur image for augmentation. - scales = [0.5, 2] # Scale patches for augmentation. - flip_index = [0, 1, -1] # Flip image for augmentation. + thetha = None # Rotate image by these angles for augmentation. + blur_k = None # Blur image for augmentation. + scales = None # Scale patches for augmentation. + degrade_scales = None # Degrade image for augmentation. + brightness = None # Brighten image for augmentation. + flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. - index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. + transformer_patchsize = None # Patch size of vision transformer patches. + num_patches_xy = None # Number of patches for vision transformer. + index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. 
@@ -66,15 +76,19 @@ def config_params(): @ex.automain -def run(n_classes, n_epochs, input_height, +def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, - blur_aug, scaling, binarization, - blur_k, scales, dir_train, data_is_provided, - scaling_bluring, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, continue_training, - flip_index, dir_eval, dir_output, pretraining, learning_rate): + blur_aug, padding_white, padding_black, scaling, degrading, + brightening, binarization, blur_k, scales, degrade_scales, + brightness, dir_train, data_is_provided, scaling_bluring, + scaling_brightness, scaling_binarization, rotation, rotation_not_90, + thetha, scaling_flip, continue_training, transformer_patchsize, + num_patches_xy, model_name, flip_index, dir_eval, dir_output, + pretraining, learning_rate): + + num_patches = num_patches_xy[0]*num_patches_xy[1] if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -121,23 +135,28 @@ def run(n_classes, n_epochs, input_height, # set the gpu configuration configuration() + + imgs_list=np.array(os.listdir(dir_img)) + segs_list=np.array(os.listdir(dir_seg)) + + imgs_list_test=np.array(os.listdir(dir_img_val)) + segs_list_test=np.array(os.listdir(dir_seg_val)) # writing patches into a sub-folder in order to be flowed from directory. 
- provide_patches(dir_img, dir_seg, dir_flow_train_imgs, - dir_flow_train_labels, - input_height, input_width, blur_k, blur_aug, - flip_aug, binarization, scaling, scales, flip_index, - scaling_bluring, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, - augmentation=augmentation, patches=patches) - - provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs, - dir_flow_eval_labels, - input_height, input_width, blur_k, blur_aug, - flip_aug, binarization, scaling, scales, flip_index, - scaling_bluring, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, - augmentation=False, patches=patches) + provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, + dir_flow_train_labels, input_height, input_width, blur_k, + blur_aug, padding_white, padding_black, flip_aug, binarization, + scaling, degrading, brightening, scales, degrade_scales, brightness, + flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation, + patches=patches) + + provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, + dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, + scaling, degrading, brightening, scales, degrade_scales, brightness, + flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches) if weighted_loss: weights = np.zeros(n_classes) @@ -166,38 +185,50 @@ def run(n_classes, n_epochs, input_height, weights = weights / float(np.sum(weights)) if continue_training: - if is_loss_soft_dice: - model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss: - model = load_model(dir_of_start_model, compile=True, - custom_objects={'loss': weighted_categorical_crossentropy(weights)}) 
- if not is_loss_soft_dice and not weighted_loss: - model = load_model(dir_of_start_model, compile=True) + if model_name=='resnet50_unet': + if is_loss_soft_dice: + model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True) + elif model_name=='hybrid_transformer_cnn': + if is_loss_soft_dice: + model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) else: - # get our model. index_start = 0 - model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) - - # if you want to see the model structure just uncomment model summary. - # model.summary() + if model_name=='resnet50_unet': + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + elif model_name=='hybrid_transformer_cnn': + model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining) + + #if you want to see the model structure just uncomment model summary. 
+ #model.summary() + if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if is_loss_soft_dice: + if is_loss_soft_dice: model.compile(loss=soft_dice_loss, optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if weighted_loss: model.compile(loss=weighted_categorical_crossentropy(weights), optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - + # generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, input_height=input_height, input_width=input_width, n_classes=n_classes) val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, input_height=input_height, input_width=input_width, n_classes=n_classes) - + + ##img_validation_patches = os.listdir(dir_flow_eval_imgs) + ##score_best=[] + ##score_best.append(0) for i in tqdm(range(index_start, n_epochs + index_start)): model.fit_generator( train_gen, @@ -205,9 +236,12 @@ def run(n_classes, n_epochs, input_height, validation_data=val_gen, validation_steps=1, epochs=1) - model.save(dir_output + '/' + 'model_' + str(i)) + model.save(dir_output+'/'+'model_'+str(i)) + + with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp: + json.dump(_config, fp) # encode dict into JSON - # os.system('rm -rf '+dir_train_flowing) - # os.system('rm -rf '+dir_eval_flowing) + #os.system('rm -rf '+dir_train_flowing) + #os.system('rm -rf '+dir_eval_flowing) - # model.save(dir_output+'/'+'model'+'.h5') + #model.save(dir_output+'/'+'model'+'.h5') diff --git a/train/utils.py b/train/utils.py index 7c65f18..c2786ec 100644 --- a/train/utils.py +++ b/train/utils.py @@ -9,6 +9,15 @@ from tqdm import tqdm import imutils import math +def do_brightening(img_in_dir, factor): + im = Image.open(img_in_dir) + enhancer = ImageEnhance.Brightness(im) + out_img = enhancer.enhance(factor) + out_img = out_img.convert('RGB') + opencv_img = np.array(out_img) 
+ opencv_img = opencv_img[:,:,::-1].copy() + return opencv_img + def bluring(img_in, kind): if kind == 'gauss': @@ -138,11 +147,11 @@ def IoU(Yi, y_predi): FP = np.sum((Yi != c) & (y_predi == c)) FN = np.sum((Yi == c) & (y_predi != c)) IoU = TP / float(TP + FP + FN) - print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU)) + #print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU)) IoUs.append(IoU) mIoU = np.mean(IoUs) - print("_________________") - print("Mean IoU: {:4.3f}".format(mIoU)) + #print("_________________") + #print("Mean IoU: {:4.3f}".format(mIoU)) return mIoU @@ -241,124 +250,170 @@ def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer): return indexer -def do_padding(img, label, height, width): - height_new = img.shape[0] - width_new = img.shape[1] +def do_padding_white(img): + img_org_h = img.shape[0] + img_org_w = img.shape[1] + + index_start_h = 4 + index_start_w = 4 + + img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1]+ 2*index_start_w, img.shape[2])) + 255 + img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :] + + return img_padded.astype(float) + +def do_degrading(img, scale): + img_org_h = img.shape[0] + img_org_w = img.shape[1] + + img_res = resize_image(img, int(img_org_h * scale), int(img_org_w * scale)) + + return resize_image(img_res, img_org_h, img_org_w) + + +def do_padding_black(img): + img_org_h = img.shape[0] + img_org_w = img.shape[1] + + index_start_h = 4 + index_start_w = 4 + + img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1] + 2*index_start_w, img.shape[2])) + img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :] + + return img_padded.astype(float) + + +def do_padding_label(img): + img_org_h = img.shape[0] + img_org_w = img.shape[1] + + index_start_h = 4 + 
index_start_w = 4 + + img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1] + 2*index_start_w, img.shape[2])) + img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :] + + return img_padded.astype(np.int16) + +def do_padding(img, label, height, width): + height_new=img.shape[0] + width_new=img.shape[1] + h_start = 0 w_start = 0 - + if img.shape[0] < height: h_start = int(abs(height - img.shape[0]) / 2.) height_new = height - + if img.shape[1] < width: w_start = int(abs(width - img.shape[1]) / 2.) width_new = width - + img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255 label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float) - + img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :]) label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :]) - - return img_new, label_new + + return img_new,label_new def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler): if img.shape[0] < height or img.shape[1] < width: img, label = do_padding(img, label, height, width) - + img_h = img.shape[0] img_w = img.shape[1] - + height_scale = int(height * scaler) width_scale = int(width * scaler) - + + nxf = img_w / float(width_scale) nyf = img_h / float(height_scale) - + if nxf > int(nxf): nxf = int(nxf) + 1 if nyf > int(nyf): nyf = int(nyf) + 1 - + nxf = int(nxf) nyf = int(nyf) - + for i in range(nxf): for j in range(nyf): index_x_d = i * width_scale index_x_u = (i + 1) * width_scale - + index_y_d = j * height_scale index_y_u = (j + 1) * height_scale - + if index_x_u > img_w: index_x_u = img_w index_x_d = img_w - width_scale if index_y_u > img_h: index_y_u = img_h index_y_d = img_h - height_scale - + + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] - + 
img_patch = resize_image(img_patch, height, width) label_patch = resize_image(label_patch, height, width) - + cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) indexer += 1 - + return indexer def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler): img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler)) label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler)) - + if img.shape[0] < height or img.shape[1] < width: img, label = do_padding(img, label, height, width) - + img_h = img.shape[0] img_w = img.shape[1] - + height_scale = int(height * 1) width_scale = int(width * 1) - + nxf = img_w / float(width_scale) nyf = img_h / float(height_scale) - + if nxf > int(nxf): nxf = int(nxf) + 1 if nyf > int(nyf): nyf = int(nyf) + 1 - + nxf = int(nxf) nyf = int(nyf) - + for i in range(nxf): for j in range(nyf): index_x_d = i * width_scale index_x_u = (i + 1) * width_scale - + index_y_d = j * height_scale index_y_u = (j + 1) * height_scale - + if index_x_u > img_w: index_x_u = img_w index_x_d = img_w - width_scale if index_y_u > img_h: index_y_u = img_h index_y_d = img_h - height_scale - + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] - - # img_patch=resize_image(img_patch,height,width) - # label_patch=resize_image(label_patch,height,width) - + cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) indexer += 1 @@ -366,78 +421,65 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i return indexer -def provide_patches(dir_img, dir_seg, dir_flow_train_imgs, - dir_flow_train_labels, - input_height, input_width, blur_k, blur_aug, - flip_aug, binarization, scaling, scales, flip_index, - scaling_bluring, 
scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, - augmentation=False, patches=False): - imgs_cv_train = np.array(os.listdir(dir_img)) - segs_cv_train = np.array(os.listdir(dir_seg)) - +def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, + dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, + padding_white, padding_black, flip_aug, binarization, scaling, degrading, + brightening, scales, degrade_scales, brightness, flip_index, + scaling_bluring, scaling_brightness, scaling_binarization, rotation, + rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False): + indexer = 0 - for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)): + for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): img_name = im.split('.')[0] if not patches: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) indexer += 1 - + if augmentation: if flip_aug: for f_i in flip_index: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, - input_width)) - + resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), - input_height, input_width)) + resize_image(cv2.flip(cv2.imread(dir_seg + '/' + 
img_name + '.png'), f_i), input_height, input_width)) indexer += 1 - - if blur_aug: + + if blur_aug: for blur_i in blur_k: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, - input_width))) - + (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width))) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, - input_width)) + resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) indexer += 1 - + if binarization: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) - + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) indexer += 1 - + + if patches: - indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width, indexer=indexer) - + if augmentation: - if rotation: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, - rotation_90(cv2.imread(dir_img + '/' + im)), - rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), - input_height, input_width, indexer=indexer) - + rotation_90(cv2.imread(dir_img + '/' + im)), + rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), + input_height, input_width, indexer=indexer) + if rotation_not_90: - for thetha_i in thetha: - img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im), - cv2.imread( - dir_seg + '/' + img_name + '.png'), - thetha_i) + img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i) indexer = 
get_patches(dir_flow_train_imgs, dir_flow_train_labels, img_max_rotated, label_max_rotated, @@ -448,47 +490,84 @@ def provide_patches(dir_img, dir_seg, dir_flow_train_imgs, cv2.flip(cv2.imread(dir_img + '/' + im), f_i), cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width, indexer=indexer) - if blur_aug: + if blur_aug: for blur_i in blur_k: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), cv2.imread(dir_seg + '/' + img_name + '.png'), - input_height, input_width, indexer=indexer) - - if scaling: + input_height, input_width, indexer=indexer) + if padding_black: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + do_padding_black(cv2.imread(dir_img + '/' + im)), + do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + input_height, input_width, indexer=indexer) + + if padding_white: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + do_padding_white(cv2.imread(dir_img + '/'+im)), + do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + input_height, input_width, indexer=indexer) + + if brightening: + for factor in brightness: + try: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + do_brightening(dir_img + '/' +im, factor), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, indexer=indexer) + except: + pass + if scaling: for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, - cv2.imread(dir_img + '/' + im), + cv2.imread(dir_img + '/' + im) , cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width, indexer=indexer, scaler=sc_ind) + + if degrading: + for degrade_scale_ind in degrade_scales: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind), + cv2.imread(dir_seg + '/' + img_name + '.png'), + input_height, input_width, 
indexer=indexer) + if binarization: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width, indexer=indexer) - if scaling_bluring: + if scaling_brightness: + for sc_ind in scales: + for factor in brightness: + try: + indexer = get_patches_num_scale_new(dir_flow_train_imgs, + dir_flow_train_labels, + do_brightening(dir_img + '/' + im, factor) + ,cv2.imread(dir_seg + '/' + img_name + '.png') + ,input_height, input_width, indexer=indexer, scaler=sc_ind) + except: + pass + + if scaling_bluring: for sc_ind in scales: for blur_i in blur_k: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), cv2.imread(dir_seg + '/' + img_name + '.png'), - input_height, input_width, indexer=indexer, - scaler=sc_ind) + input_height, input_width, indexer=indexer, scaler=sc_ind) - if scaling_binarization: + if scaling_binarization: for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width, indexer=indexer, scaler=sc_ind) - - if scaling_flip: + + if scaling_flip: for sc_ind in scales: for f_i in flip_index: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, - cv2.flip(cv2.imread(dir_img + '/' + im), f_i), - cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), - f_i), - input_height, input_width, indexer=indexer, - scaler=sc_ind) + cv2.flip( cv2.imread(dir_img + '/' + im), f_i), + cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), + input_height, input_width, indexer=indexer, scaler=sc_ind) From ca63c097c3c30b58513d708f476139c590ac2d94 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 29 Apr 2024 20:59:36 +0200 Subject: [PATCH 040/123] integrating first working classification training 
model --- train/config_params.json | 20 ++- train/models.py | 69 +++++++- train/requirements.txt | 1 + train/train.py | 374 ++++++++++++++++++++++++--------------- train/utils.py | 113 ++++++++++++ 5 files changed, 419 insertions(+), 158 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index bd47a52..43ad1bc 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,13 +1,15 @@ { - "model_name" : "hybrid_transformer_cnn", + "model_name" : "resnet50_unet", + "task": "classification", "n_classes" : 2, - "n_epochs" : 2, - "input_height" : 448, - "input_width" : 448, + "n_epochs" : 7, + "input_height" : 224, + "input_width" : 224, "weight_decay" : 1e-6, - "n_batch" : 2, + "n_batch" : 6, "learning_rate": 1e-4, - "patches" : true, + "f1_threshold_classification": 0.8, + "patches" : false, "pretraining" : true, "augmentation" : false, "flip_aug" : false, @@ -33,7 +35,7 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/train", - "dir_eval": "/eval", - "dir_output": "/out" + "dir_train": "/home/vahid/Downloads/image_classification_data/train", + "dir_eval": "/home/vahid/Downloads/image_classification_data/eval", + "dir_output": "/home/vahid/Downloads/image_classification_data/output" } diff --git a/train/models.py b/train/models.py index f7a7ad8..a6de1ef 100644 --- a/train/models.py +++ b/train/models.py @@ -400,7 +400,7 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_ f5 = x if pretraining: - model = keras.Model(inputs, x).load_weights(resnet50_Weights_path) + model = Model(inputs, x).load_weights(resnet50_Weights_path) num_patches = x.shape[1]*x.shape[2] patches = Patches(patch_size)(x) @@ -468,6 +468,71 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_ o = (BatchNormalization(axis=bn_axis))(o) o = (Activation('softmax'))(o) - model = keras.Model(inputs=inputs, outputs=o) + model = Model(inputs=inputs, 
outputs=o) + return model + +def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + include_top=True + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model(img_input, x).load_weights(resnet50_Weights_path) + + x = AveragePooling2D((7, 7), name='avg_pool')(x) + x = Flatten()(x) + + ## + x = Dense(256, 
activation='relu', name='fc512')(x) + x=Dropout(0.2)(x) + ## + x = Dense(n_classes, activation='softmax', name='fc1000')(x) + model = Model(img_input, x) + + + + return model diff --git a/train/requirements.txt b/train/requirements.txt index 20b6a32..3e56438 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -6,3 +6,4 @@ tqdm imutils numpy scipy +scikit-learn diff --git a/train/train.py b/train/train.py index 6e6a172..efcd3ac 100644 --- a/train/train.py +++ b/train/train.py @@ -11,6 +11,7 @@ from metrics import * from tensorflow.keras.models import load_model from tqdm import tqdm import json +from sklearn.metrics import f1_score def configuration(): @@ -73,6 +74,8 @@ def config_params(): is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". + task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification. + f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. 
@ex.automain @@ -86,162 +89,239 @@ def run(_config, n_classes, n_epochs, input_height, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_patchsize, num_patches_xy, model_name, flip_index, dir_eval, dir_output, - pretraining, learning_rate): + pretraining, learning_rate, task, f1_threshold_classification): - num_patches = num_patches_xy[0]*num_patches_xy[1] - if data_is_provided: - dir_train_flowing = os.path.join(dir_output, 'train') - dir_eval_flowing = os.path.join(dir_output, 'eval') - - dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') - dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') - - dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images') - dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels') - - configuration() - - else: - dir_img, dir_seg = get_dirs_or_files(dir_train) - dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) - - # make first a directory in output for both training and evaluations in order to flow data from these directories. 
- dir_train_flowing = os.path.join(dir_output, 'train') - dir_eval_flowing = os.path.join(dir_output, 'eval') - - dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/') - dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/') - - dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/') - dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/') - - if os.path.isdir(dir_train_flowing): - os.system('rm -rf ' + dir_train_flowing) - os.makedirs(dir_train_flowing) - else: - os.makedirs(dir_train_flowing) - - if os.path.isdir(dir_eval_flowing): - os.system('rm -rf ' + dir_eval_flowing) - os.makedirs(dir_eval_flowing) - else: - os.makedirs(dir_eval_flowing) - - os.mkdir(dir_flow_train_imgs) - os.mkdir(dir_flow_train_labels) - - os.mkdir(dir_flow_eval_imgs) - os.mkdir(dir_flow_eval_labels) - - # set the gpu configuration - configuration() + if task == "segmentation": - imgs_list=np.array(os.listdir(dir_img)) - segs_list=np.array(os.listdir(dir_seg)) - - imgs_list_test=np.array(os.listdir(dir_img_val)) - segs_list_test=np.array(os.listdir(dir_seg_val)) - - # writing patches into a sub-folder in order to be flowed from directory. 
- provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, - dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, - scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation, - patches=patches) - - provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, - dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, - scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches) - - if weighted_loss: - weights = np.zeros(n_classes) + num_patches = num_patches_xy[0]*num_patches_xy[1] if data_is_provided: - for obj in os.listdir(dir_flow_train_labels): - try: - label_obj = cv2.imread(dir_flow_train_labels + '/' + obj) - label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) - weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) - except: - pass + dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') + + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') + + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels') + + configuration() + else: + dir_img, dir_seg = get_dirs_or_files(dir_train) + dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval) - for obj in os.listdir(dir_seg): - try: - label_obj = cv2.imread(dir_seg + '/' + obj) - label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], 
n_classes) - weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) - except: - pass + # make first a directory in output for both training and evaluations in order to flow data from these directories. + dir_train_flowing = os.path.join(dir_output, 'train') + dir_eval_flowing = os.path.join(dir_output, 'eval') - weights = 1.00 / weights + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/') + dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/') - weights = weights / float(np.sum(weights)) - weights = weights / float(np.min(weights)) - weights = weights / float(np.sum(weights)) + dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/') + dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/') - if continue_training: - if model_name=='resnet50_unet': - if is_loss_soft_dice: - model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss: - model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) - if not is_loss_soft_dice and not weighted_loss: - model = load_model(dir_of_start_model , compile=True) - elif model_name=='hybrid_transformer_cnn': - if is_loss_soft_dice: - model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) - if weighted_loss: - model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) - if not is_loss_soft_dice and not weighted_loss: - model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) - else: - index_start = 0 - if model_name=='resnet50_unet': - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) - elif model_name=='hybrid_transformer_cnn': - model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, 
input_height, input_width,weight_decay,pretraining) - - #if you want to see the model structure just uncomment model summary. - #model.summary() - + if os.path.isdir(dir_train_flowing): + os.system('rm -rf ' + dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) - if not is_loss_soft_dice and not weighted_loss: + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf ' + dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) + + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) + + os.mkdir(dir_flow_eval_imgs) + os.mkdir(dir_flow_eval_labels) + + # set the gpu configuration + configuration() + + imgs_list=np.array(os.listdir(dir_img)) + segs_list=np.array(os.listdir(dir_seg)) + + imgs_list_test=np.array(os.listdir(dir_img_val)) + segs_list_test=np.array(os.listdir(dir_seg_val)) + + # writing patches into a sub-folder in order to be flowed from directory. + provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, + dir_flow_train_labels, input_height, input_width, blur_k, + blur_aug, padding_white, padding_black, flip_aug, binarization, + scaling, degrading, brightening, scales, degrade_scales, brightness, + flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation, + patches=patches) + + provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, + dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, + scaling, degrading, brightening, scales, degrade_scales, brightness, + flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches) + + if weighted_loss: + weights = np.zeros(n_classes) + if data_is_provided: + for obj in os.listdir(dir_flow_train_labels): + try: + 
label_obj = cv2.imread(dir_flow_train_labels + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + else: + + for obj in os.listdir(dir_seg): + try: + label_obj = cv2.imread(dir_seg + '/' + obj) + label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes) + weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + + weights = 1.00 / weights + + weights = weights / float(np.sum(weights)) + weights = weights / float(np.min(weights)) + weights = weights / float(np.sum(weights)) + + if continue_training: + if model_name=='resnet50_unet': + if is_loss_soft_dice: + model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True) + elif model_name=='hybrid_transformer_cnn': + if is_loss_soft_dice: + model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + else: + index_start = 0 + if model_name=='resnet50_unet': + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + elif model_name=='hybrid_transformer_cnn': + model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining) + + #if you want to see the model 
structure just uncomment model summary. + #model.summary() + + + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + + # generating train and evaluation data + train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes) + val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, + input_height=input_height, input_width=input_width, n_classes=n_classes) + + ##img_validation_patches = os.listdir(dir_flow_eval_imgs) + ##score_best=[] + ##score_best.append(0) + for i in tqdm(range(index_start, n_epochs + index_start)): + model.fit_generator( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + validation_data=val_gen, + validation_steps=1, + epochs=1) + model.save(dir_output+'/'+'model_'+str(i)) + + with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp: + json.dump(_config, fp) # encode dict into JSON + + #os.system('rm -rf '+dir_train_flowing) + #os.system('rm -rf '+dir_eval_flowing) + + #model.save(dir_output+'/'+'model'+'.h5') + elif task=='classification': + configuration() + model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining) + + opt_adam = Adam(learning_rate=0.001) model.compile(loss='categorical_crossentropy', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if is_loss_soft_dice: - model.compile(loss=soft_dice_loss, - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if weighted_loss: - model.compile(loss=weighted_categorical_crossentropy(weights), - 
optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - - # generating train and evaluation data - train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) - val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) - - ##img_validation_patches = os.listdir(dir_flow_eval_imgs) - ##score_best=[] - ##score_best.append(0) - for i in tqdm(range(index_start, n_epochs + index_start)): - model.fit_generator( - train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, - validation_data=val_gen, - validation_steps=1, - epochs=1) - model.save(dir_output+'/'+'model_'+str(i)) - - with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp: - json.dump(_config, fp) # encode dict into JSON + optimizer = opt_adam,metrics=['accuracy']) - #os.system('rm -rf '+dir_train_flowing) - #os.system('rm -rf '+dir_eval_flowing) - #model.save(dir_output+'/'+'model'+'.h5') + testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes) + + #print(testY.shape, testY) + + y_tot=np.zeros((testX.shape[0],n_classes)) + indexer=0 + + score_best=[] + score_best.append(0) + + num_rows = return_number_of_total_training_data(dir_train) + + weights=[] + + for i in range(n_epochs): + #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights) + history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights) + + y_pr_class = [] + for jj in range(testY.shape[0]): + y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0) + y_pr_ind= np.argmax(y_pr,axis=1) + #print(y_pr_ind, 'y_pr_ind') + 
y_pr_class.append(y_pr_ind) + + + y_pr_class = np.array(y_pr_class) + #model.save('./models_save/model_'+str(i)+'.h5') + #y_pr_class=np.argmax(y_pr,axis=1) + f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro') + + print(i,f1score) + + if f1score>score_best[0]: + score_best[0]=f1score + model.save(os.path.join(dir_output,'model_best')) + + + ##best_model=keras.models.clone_model(model) + ##best_model.build() + ##best_model.set_weights(model.get_weights()) + if f1score > f1_threshold_classification: + weights.append(model.get_weights() ) + y_tot=y_tot+y_pr + + indexer+=1 + y_tot=y_tot/float(indexer) + + + new_weights=list() + + for weights_list_tuple in zip(*weights): + new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) + + new_weights = [np.array(x) for x in new_weights] + + model_weight_averaged=tf.keras.models.clone_model(model) + + model_weight_averaged.set_weights(new_weights) + + #y_tot_end=np.argmax(y_tot,axis=1) + #print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro')) + + ##best_model.save('model_taza.h5') + model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) + diff --git a/train/utils.py b/train/utils.py index c2786ec..af3c5f8 100644 --- a/train/utils.py +++ b/train/utils.py @@ -8,6 +8,119 @@ import random from tqdm import tqdm import imutils import math +from tensorflow.keras.utils import to_categorical + + +def return_number_of_total_training_data(path_classes): + sub_classes = os.listdir(path_classes) + n_tot = 0 + for sub_c in sub_classes: + sub_files = os.listdir(os.path.join(path_classes,sub_c)) + n_tot = n_tot + len(sub_files) + return n_tot + + + +def generate_data_from_folder_evaluation(path_classes, height, width, n_classes): + sub_classes = os.listdir(path_classes) + #n_classes = len(sub_classes) + all_imgs = [] + labels = [] + dicts =dict() + indexer= 0 + for sub_c in sub_classes: + sub_files = os.listdir(os.path.join(path_classes,sub_c )) + sub_files 
= [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] + #print( os.listdir(os.path.join(path_classes,sub_c )) ) + all_imgs = all_imgs + sub_files + sub_labels = list( np.zeros( len(sub_files) ) +indexer ) + + #print( len(sub_labels) ) + labels = labels + sub_labels + dicts[sub_c] = indexer + indexer +=1 + + + categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] + ret_x= np.zeros((len(labels), height,width, 3)).astype(np.int16) + ret_y= np.zeros((len(labels), n_classes)).astype(np.int16) + + #print(all_imgs) + for i in range(len(all_imgs)): + row = all_imgs[i] + #####img = cv2.imread(row, 0) + #####img= resize_image (img, height, width) + #####img = img.astype(np.uint16) + #####ret_x[i, :,:,0] = img[:,:] + #####ret_x[i, :,:,1] = img[:,:] + #####ret_x[i, :,:,2] = img[:,:] + + img = cv2.imread(row) + img= resize_image (img, height, width) + img = img.astype(np.uint16) + ret_x[i, :,:] = img[:,:,:] + + ret_y[i, :] = categories[ int( labels[i] ) ][:] + + return ret_x/255., ret_y + +def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes): + sub_classes = os.listdir(path_classes) + n_classes = len(sub_classes) + + all_imgs = [] + labels = [] + dicts =dict() + indexer= 0 + for sub_c in sub_classes: + sub_files = os.listdir(os.path.join(path_classes,sub_c )) + sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] + #print( os.listdir(os.path.join(path_classes,sub_c )) ) + all_imgs = all_imgs + sub_files + sub_labels = list( np.zeros( len(sub_files) ) +indexer ) + + #print( len(sub_labels) ) + labels = labels + sub_labels + dicts[sub_c] = indexer + indexer +=1 + + ids = np.array(range(len(labels))) + random.shuffle(ids) + + shuffled_labels = np.array(labels)[ids] + shuffled_files = np.array(all_imgs)[ids] + categories = 
to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] + ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 + while True: + for i in range(len(shuffled_files)): + row = shuffled_files[i] + #print(row) + ###img = cv2.imread(row, 0) + ###img= resize_image (img, height, width) + ###img = img.astype(np.uint16) + ###ret_x[batchcount, :,:,0] = img[:,:] + ###ret_x[batchcount, :,:,1] = img[:,:] + ###ret_x[batchcount, :,:,2] = img[:,:] + + img = cv2.imread(row) + img= resize_image (img, height, width) + img = img.astype(np.uint16) + ret_x[batchcount, :,:,:] = img[:,:,:] + + #print(int(shuffled_labels[i]) ) + #print( categories[int(shuffled_labels[i])] ) + ret_y[batchcount, :] = categories[ int( shuffled_labels[i] ) ][:] + + batchcount+=1 + + if batchcount>=batchsize: + ret_x = ret_x/255. 
+ yield (ret_x, ret_y) + ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 def do_brightening(img_in_dir, factor): im = Image.open(img_in_dir) From c989f7ac6111314a394700e833abe351f5daae43 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 6 May 2024 18:31:48 +0200 Subject: [PATCH 041/123] adding enhancement training --- train/config_params.json | 20 +++++----- train/gt_for_enhancement_creator.py | 31 +++++++++++++++ train/models.py | 27 ++++++++----- train/train.py | 47 ++++++++++++---------- train/utils.py | 62 ++++++++++++++++------------- 5 files changed, 119 insertions(+), 68 deletions(-) create mode 100644 train/gt_for_enhancement_creator.py diff --git a/train/config_params.json b/train/config_params.json index 43ad1bc..1c7a940 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,15 +1,15 @@ { "model_name" : "resnet50_unet", - "task": "classification", - "n_classes" : 2, - "n_epochs" : 7, - "input_height" : 224, - "input_width" : 224, + "task": "enhancement", + "n_classes" : 3, + "n_epochs" : 3, + "input_height" : 448, + "input_width" : 448, "weight_decay" : 1e-6, - "n_batch" : 6, + "n_batch" : 3, "learning_rate": 1e-4, "f1_threshold_classification": 0.8, - "patches" : false, + "patches" : true, "pretraining" : true, "augmentation" : false, "flip_aug" : false, @@ -35,7 +35,7 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/home/vahid/Downloads/image_classification_data/train", - "dir_eval": "/home/vahid/Downloads/image_classification_data/eval", - "dir_output": "/home/vahid/Downloads/image_classification_data/output" + "dir_train": "./training_data_sample_enhancement", + "dir_eval": "./eval", + "dir_output": "./out" } diff --git a/train/gt_for_enhancement_creator.py b/train/gt_for_enhancement_creator.py new file mode 100644 index 0000000..9a4274f --- /dev/null +++ 
b/train/gt_for_enhancement_creator.py @@ -0,0 +1,31 @@ +import cv2 +import os + +def resize_image(seg_in, input_height, input_width): + return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) + + +dir_imgs = './training_data_sample_enhancement/images' +dir_out_imgs = './training_data_sample_enhancement/images_gt' +dir_out_labs = './training_data_sample_enhancement/labels_gt' + +ls_imgs = os.listdir(dir_imgs) + + +ls_scales = [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] + + +for img in ls_imgs: + img_name = img.split('.')[0] + img_type = img.split('.')[1] + image = cv2.imread(os.path.join(dir_imgs, img)) + for i, scale in enumerate(ls_scales): + height_sc = int(image.shape[0]*scale) + width_sc = int(image.shape[1]*scale) + + image_down_scaled = resize_image(image, height_sc, width_sc) + image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1]) + + cv2.imwrite(os.path.join(dir_out_imgs, img_name+'_'+str(i)+'.'+img_type), image_back_to_org_scale) + cv2.imwrite(os.path.join(dir_out_labs, img_name+'_'+str(i)+'.'+img_type), image) + diff --git a/train/models.py b/train/models.py index a6de1ef..4cceacd 100644 --- a/train/models.py +++ b/train/models.py @@ -168,7 +168,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) return x -def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): +def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 @@ -259,14 +259,17 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_dec o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = 
(BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(img_input, o) return model -def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): +def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 @@ -354,15 +357,18 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e- o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(img_input, o) return model -def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): +def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -465,8 +471,11 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_ o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(inputs=inputs, outputs=o) diff --git a/train/train.py b/train/train.py index efcd3ac..595debe 100644 --- a/train/train.py +++ b/train/train.py @@ -1,5 +1,6 @@ import os import sys 
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf from tensorflow.compat.v1.keras.backend import set_session import warnings @@ -91,7 +92,7 @@ def run(_config, n_classes, n_epochs, input_height, num_patches_xy, model_name, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification): - if task == "segmentation": + if task == "segmentation" or "enhancement": num_patches = num_patches_xy[0]*num_patches_xy[1] if data_is_provided: @@ -153,7 +154,7 @@ def run(_config, n_classes, n_epochs, input_height, blur_aug, padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, patches=patches) provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, @@ -161,7 +162,7 @@ def run(_config, n_classes, n_epochs, input_height, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches) + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) if weighted_loss: weights = np.zeros(n_classes) @@ -191,45 +192,49 @@ def run(_config, n_classes, n_epochs, input_height, if continue_training: if model_name=='resnet50_unet': - if is_loss_soft_dice: + if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss: + if weighted_loss and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': 
weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True) elif model_name=='hybrid_transformer_cnn': - if is_loss_soft_dice: + if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) - if weighted_loss: + if weighted_loss and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) else: index_start = 0 if model_name=='resnet50_unet': - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) elif model_name=='hybrid_transformer_cnn': - model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining) + model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
#model.summary() - - if not is_loss_soft_dice and not weighted_loss: - model.compile(loss='categorical_crossentropy', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if is_loss_soft_dice: - model.compile(loss=soft_dice_loss, - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if weighted_loss: - model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if task == "segmentation": + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + elif task == "enhancement": + model.compile(loss='mean_squared_error', + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + # generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) ##img_validation_patches = os.listdir(dir_flow_eval_imgs) ##score_best=[] diff --git a/train/utils.py b/train/utils.py index af3c5f8..0c5a458 100644 --- a/train/utils.py +++ b/train/utils.py @@ -268,7 +268,7 @@ def IoU(Yi, y_predi): return mIoU -def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes): +def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'): c = 0 n 
= [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images random.shuffle(n) @@ -277,8 +277,6 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') for i in range(c, c + batch_size): # initially from 0 to 16, c = 0. - # print(img_folder+'/'+n[i]) - try: filename = n[i].split('.')[0] @@ -287,11 +285,14 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize img[i - c] = train_img # add to array - img[0], img[1], and so on. - train_mask = cv2.imread(mask_folder + '/' + filename + '.png') - # print(mask_folder+'/'+filename+'.png') - # print(train_mask.shape) - train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, - n_classes) + if task == "segmentation": + train_mask = cv2.imread(mask_folder + '/' + filename + '.png') + train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, + n_classes) + elif task == "enhancement": + train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255. 
+ train_mask = resize_image(train_mask, input_height, input_width) + # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] mask[i - c] = train_mask @@ -539,14 +540,19 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False): + rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False): indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): img_name = im.split('.')[0] + if task == "segmentation": + dir_of_label_file = os.path.join(dir_seg, img_name + '.png') + elif task=="enhancement": + dir_of_label_file = os.path.join(dir_seg, im) + if not patches: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if augmentation: @@ -556,7 +562,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width)) + resize_image(cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width)) indexer += 1 if blur_aug: @@ -565,7 +571,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, 
dir_seg, dir_flow (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width))) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if binarization: @@ -573,26 +579,26 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if patches: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, - cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_img + '/' + im), cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if augmentation: if rotation: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, rotation_90(cv2.imread(dir_img + '/' + im)), - rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), + rotation_90(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if rotation_not_90: for thetha_i in thetha: img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im), - cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i) + cv2.imread(dir_of_label_file), thetha_i) indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, img_max_rotated, label_max_rotated, @@ -601,24 +607,24 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for f_i in flip_index: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, cv2.flip(cv2.imread(dir_img + '/' + im), f_i), - cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), 
f_i), + cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width, indexer=indexer) if blur_aug: for blur_i in blur_k: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if padding_black: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_padding_black(cv2.imread(dir_img + '/' + im)), - do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + do_padding_label(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if padding_white: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_padding_white(cv2.imread(dir_img + '/'+im)), - do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + do_padding_label(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if brightening: @@ -626,7 +632,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow try: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_brightening(dir_img + '/' +im, factor), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) except: pass @@ -634,20 +640,20 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, cv2.imread(dir_img + '/' + im) , - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if degrading: for degrade_scale_ind in degrade_scales: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, 
input_width, indexer=indexer) if binarization: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if scaling_brightness: @@ -657,7 +663,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, do_brightening(dir_img + '/' + im, factor) - ,cv2.imread(dir_seg + '/' + img_name + '.png') + ,cv2.imread(dir_of_label_file) ,input_height, input_width, indexer=indexer, scaler=sc_ind) except: pass @@ -667,14 +673,14 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for blur_i in blur_k: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if scaling_binarization: for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if scaling_flip: @@ -682,5 +688,5 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for f_i in flip_index: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, cv2.flip( cv2.imread(dir_img + '/' + im), f_i), - cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), + cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width, indexer=indexer, scaler=sc_ind) From e1f62c2e9827030e3386ff678a131481d70e8e14 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 7 May 2024 13:34:03 +0200 Subject: [PATCH 042/123] inference 
script is added --- train/config_params.json | 17 +- train/inference.py | 490 +++++++++++++++++++++++++++++++++++++++ train/train.py | 42 ++-- train/utils.py | 30 +-- 4 files changed, 537 insertions(+), 42 deletions(-) create mode 100644 train/inference.py diff --git a/train/config_params.json b/train/config_params.json index 1c7a940..8a56de5 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,12 +1,12 @@ { - "model_name" : "resnet50_unet", - "task": "enhancement", - "n_classes" : 3, - "n_epochs" : 3, + "backbone_type" : "nontransformer", + "task": "classification", + "n_classes" : 2, + "n_epochs" : 20, "input_height" : 448, "input_width" : 448, "weight_decay" : 1e-6, - "n_batch" : 3, + "n_batch" : 6, "learning_rate": 1e-4, "f1_threshold_classification": 0.8, "patches" : true, @@ -21,7 +21,7 @@ "scaling_flip" : false, "rotation": false, "rotation_not_90": false, - "num_patches_xy": [28, 28], + "transformer_num_patches_xy": [28, 28], "transformer_patchsize": 1, "blur_k" : ["blur","guass","median"], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], @@ -29,13 +29,14 @@ "degrade_scales" : [0.2, 0.4], "flip_index" : [0, 1, -1], "thetha" : [10, -10], + "classification_classes_name" : {"0":"apple", "1":"orange"}, "continue_training": false, "index_start" : 0, "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "./training_data_sample_enhancement", + "dir_train": "./train", "dir_eval": "./eval", - "dir_output": "./out" + "dir_output": "./output" } diff --git a/train/inference.py b/train/inference.py new file mode 100644 index 0000000..6911bea --- /dev/null +++ b/train/inference.py @@ -0,0 +1,490 @@ +#! 
/usr/bin/env python3 + +__version__= '1.0' + +import argparse +import sys +import os +import numpy as np +import warnings +import xml.etree.ElementTree as et +import pandas as pd +from tqdm import tqdm +import csv +import cv2 +import seaborn as sns +import matplotlib.pyplot as plt +from tensorflow.keras.models import load_model +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow.keras import layers +import tensorflow.keras.losses +from tensorflow.keras.layers import * +import click +import json +from tensorflow.python.keras import backend as tensorflow_backend + + + + + + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + +__doc__=\ +""" +Tool to load model and predict for given image. +""" + +projection_dim = 64 +patch_size = 1 +num_patches =28*28 +class Patches(layers.Layer): + def __init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + print(tf.shape(images)[1],'images') + print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + print(patches.shape,patch_dims,'patch_dims') + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config + + +class PatchEncoder(layers.Layer): + def __init__(self, **kwargs): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = 
self.projection(patch) + self.position_embedding(positions) + return encoded + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + + +class sbb_predict: + def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ): + self.image=image + self.patches=patches + self.save=save + self.model_dir=model + self.ground_truth=ground_truth + self.weights_dir=weights_dir + self.task=task + self.config_params_model=config_params_model + + def resize_image(self,img_in,input_height,input_width): + return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) + + + def color_images(self,seg): + ann_u=range(self.n_classes) + if len(np.shape(seg))==3: + seg=seg[:,:,0] + + seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8) + colors=sns.color_palette("hls", self.n_classes) + + for c in ann_u: + c=int(c) + segl=(seg==c) + seg_img[:,:,0][seg==c]=c + seg_img[:,:,1][seg==c]=c + seg_img[:,:,2][seg==c]=c + return seg_img + + def otsu_copy_binary(self,img): + img_r=np.zeros((img.shape[0],img.shape[1],3)) + img1=img[:,:,0] + + #print(img.min()) + #print(img[:,:,0].min()) + #blur = cv2.GaussianBlur(img,(5,5)) + #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) + retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + + + + img_r[:,:,0]=threshold1 + img_r[:,:,1]=threshold1 + img_r[:,:,2]=threshold1 + #img_r=img_r/float(np.max(img_r))*255 + return img_r + + def otsu_copy(self,img): + img_r=np.zeros((img.shape[0],img.shape[1],3)) + #img1=img[:,:,0] + + #print(img.min()) + #print(img[:,:,0].min()) + #blur = cv2.GaussianBlur(img,(5,5)) + #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold1 = cv2.threshold(img[:,:,0], 0, 255, 
cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + _, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) + + + + img_r[:,:,0]=threshold1 + img_r[:,:,1]=threshold2 + img_r[:,:,2]=threshold3 + ###img_r=img_r/float(np.max(img_r))*255 + return img_r + + def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6): + + axes = tuple(range(1, len(y_pred.shape)-1)) + + numerator = 2. * K.sum(y_pred * y_true, axes) + + denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + + def weighted_categorical_crossentropy(self,weights=None): + + def loss(y_true, y_pred): + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum(tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + return tf.reduce_mean(per_pixel_loss) + return loss + + + def IoU(self,Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + + IoUs = [] + Nclass = np.unique(Yi) + for c in Nclass: + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + if self.n_classes>2: + print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU)) + IoUs.append(IoU) + if self.n_classes>2: + mIoU = np.mean(IoUs) + print("_________________") + print("Mean IoU: {:4.3f}".format(mIoU)) + return mIoU + elif self.n_classes==2: + mIoU = IoUs[1] + print("_________________") + print("IoU: {:4.3f}".format(mIoU)) + return mIoU + + def start_new_session_and_model(self): + + config = tf.compat.v1.ConfigProto() + 
config.gpu_options.allow_growth = True + + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + #tensorflow.keras.layers.custom_layer = PatchEncoder + #tensorflow.keras.layers.custom_layer = Patches + self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + #config = tf.ConfigProto() + #config.gpu_options.allow_growth=True + + #self.session = tf.InteractiveSession() + #keras.losses.custom_loss = self.weighted_categorical_crossentropy + #self.model = load_model(self.model_dir , compile=False) + + + ##if self.weights_dir!=None: + ##self.model.load_weights(self.weights_dir) + + if self.task != 'classification': + self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] + self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] + self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] + + def visualize_model_output(self, prediction, img, task): + if task == "binarization": + prediction = prediction * -1 + prediction = prediction + 1 + added_image = prediction * 255 + else: + unique_classes = np.unique(prediction[:,:,0]) + rgb_colors = {'0' : [255, 255, 255], + '1' : [255, 0, 0], + '2' : [255, 125, 0], + '3' : [255, 0, 125], + '4' : [125, 125, 125], + '5' : [125, 125, 0], + '6' : [0, 125, 255], + '7' : [0, 125, 0], + '8' : [125, 125, 125], + '9' : [0, 125, 255], + '10' : [125, 0, 125], + '11' : [0, 255, 0], + '12' : [0, 0, 255], + '13' : [0, 255, 255], + '14' : [255, 125, 125], + '15' : [255, 0, 255]} + + output = np.zeros(prediction.shape) + + for unq_class in unique_classes: + rgb_class_unique = rgb_colors[str(int(unq_class))] + output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + + + + img = self.resize_image(img, output.shape[0], 
output.shape[1]) + + output = output.astype(np.int32) + img = img.astype(np.int32) + + + + added_image = cv2.addWeighted(img,0.5,output,0.1,0) + + return added_image + + def predict(self): + self.start_new_session_and_model() + if self.task == 'classification': + classes_names = self.config_params_model['classification_classes_name'] + img_1ch = img=cv2.imread(self.image, 0) + + img_1ch = img_1ch / 255.0 + img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) + img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) + img_in[0, :, :, 0] = img_1ch[:, :] + img_in[0, :, :, 1] = img_1ch[:, :] + img_in[0, :, :, 2] = img_1ch[:, :] + + label_p_pred = self.model.predict(img_in, verbose=0) + index_class = np.argmax(label_p_pred[0]) + + print("Predicted Class: {}".format(classes_names[str(int(index_class))])) + else: + if self.patches: + #def textline_contours(img,input_width,input_height,n_classes,model): + + img=cv2.imread(self.image) + self.img_org = np.copy(img) + + if img.shape[0] < self.img_height: + img = cv2.resize(img, (img.shape[1], self.img_width), interpolation=cv2.INTER_NEAREST) + + if img.shape[1] < self.img_width: + img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST) + margin = int(0 * self.img_width) + width_mid = self.img_width - 2 * margin + height_mid = self.img_height - 2 * margin + img = img / float(255.0) + + img_h = img.shape[0] + img_w = img.shape[1] + + prediction_true = np.zeros((img_h, img_w, 3)) + nxf = img_w / float(width_mid) + nyf = img_h / float(height_mid) + + nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf) + nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf) + + for i in range(nxf): + for j in range(nyf): + if i == 0: + index_x_d = i * width_mid + index_x_u = index_x_d + self.img_width + else: + index_x_d = i * width_mid + index_x_u = index_x_d + self.img_width + if j == 0: + index_y_d = j * height_mid + 
index_y_u = index_y_d + self.img_height + else: + index_y_d = j * height_mid + index_y_u = index_y_d + self.img_height + + if index_x_u > img_w: + index_x_u = img_w + index_x_d = img_w - self.img_width + if index_y_u > img_h: + index_y_u = img_h + index_y_d = img_h - self.img_height + + img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] + label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]), + verbose=0) + + if self.task == 'enhancement': + seg = label_p_pred[0, :, :, :] + seg = seg * 255 + elif self.task == 'segmentation' or self.task == 'binarization': + seg = np.argmax(label_p_pred, axis=3)[0] + seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + + + if i == 0 and j == 0: + seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg + elif i == 0 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j == 0: + seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg + elif i == 0 and j != 0 and j != nyf - 1: + seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg + elif i == nxf - 1 and j != 0 and j != nyf - 1: + seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0] + prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg + elif i != 0 and i != nxf - 1 and j == 
0: + seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin] + prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg + elif i != 0 and i != nxf - 1 and j == nyf - 1: + seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg + else: + seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin] + prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg + prediction_true = prediction_true.astype(int) + prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST) + return prediction_true + + else: + + img=cv2.imread(self.image) + self.img_org = np.copy(img) + + width=self.img_width + height=self.img_height + + img=img/255.0 + img=self.resize_image(img,self.img_height,self.img_width) + + + label_p_pred=self.model.predict( + img.reshape(1,img.shape[0],img.shape[1],img.shape[2])) + + if self.task == 'enhancement': + seg = label_p_pred[0, :, :, :] + seg = seg * 255 + elif self.task == 'segmentation' or self.task == 'binarization': + seg = np.argmax(label_p_pred, axis=3)[0] + seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + + prediction_true = seg.astype(int) + + prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST) + return prediction_true + + + + def run(self): + res=self.predict() + if self.task == 'classification': + pass + else: + img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) + cv2.imwrite('./test.png',img_seg_overlayed) + ##if self.save!=None: + ##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2) + ##cv2.imwrite(self.save,img) + + ###if self.ground_truth!=None: + ###gt_img=cv2.imread(self.ground_truth) + ###self.IoU(gt_img[:,:,0],res) + ##plt.imshow(res) + 
##plt.show() + +@click.command() +@click.option( + "--image", + "-i", + help="image filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--patches/--no-patches", + "-p/-nop", + is_flag=True, + help="if this parameter set to true, this tool will try to do inference in patches.", +) +@click.option( + "--save", + "-s", + help="save prediction as a png file in current folder.", +) +@click.option( + "--model", + "-m", + help="directory of models", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--ground_truth/--no-ground_truth", + "-gt/-nogt", + is_flag=True, + help="ground truth directory if you want to see the iou of prediction.", +) +@click.option( + "--model_weights/--no-model_weights", + "-mw/-nomw", + is_flag=True, + help="previous model weights which are saved.", +) +def main(image, model, patches, save, ground_truth, model_weights): + + with open(os.path.join(model,'config.json')) as f: + config_params_model = json.load(f) + task = 'classification' + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights) + x.run() + +if __name__=="__main__": + main() + + + + diff --git a/train/train.py b/train/train.py index 595debe..28363d2 100644 --- a/train/train.py +++ b/train/train.py @@ -69,7 +69,7 @@ def config_params(): flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. transformer_patchsize = None # Patch size of vision transformer patches. - num_patches_xy = None # Number of patches for vision transformer. + transformer_num_patches_xy = None # Number of patches for vision transformer. index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. 
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. @@ -77,6 +77,8 @@ def config_params(): data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification. f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. + classification_classes_name = None # Dictionary of classification classes names. + backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer" @ex.automain @@ -89,12 +91,12 @@ def run(_config, n_classes, n_epochs, input_height, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_patchsize, - num_patches_xy, model_name, flip_index, dir_eval, dir_output, - pretraining, learning_rate, task, f1_threshold_classification): + transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): - if task == "segmentation" or "enhancement": + if task == "segmentation" or task == "enhancement": - num_patches = num_patches_xy[0]*num_patches_xy[1] + num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -191,14 +193,14 @@ def run(_config, n_classes, n_epochs, 
input_height, weights = weights / float(np.sum(weights)) if continue_training: - if model_name=='resnet50_unet': + if backbone_type=='nontransformer': if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) if weighted_loss and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True) - elif model_name=='hybrid_transformer_cnn': + elif backbone_type=='transformer': if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) if weighted_loss and task == "segmentation": @@ -207,9 +209,9 @@ def run(_config, n_classes, n_epochs, input_height, model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) else: index_start = 0 - if model_name=='resnet50_unet': + if backbone_type=='nontransformer': model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) - elif model_name=='hybrid_transformer_cnn': + elif backbone_type=='transformer': model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
@@ -246,9 +248,9 @@ def run(_config, n_classes, n_epochs, input_height, validation_data=val_gen, validation_steps=1, epochs=1) - model.save(dir_output+'/'+'model_'+str(i)) + model.save(os.path.join(dir_output,'model_'+str(i))) - with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp: + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON #os.system('rm -rf '+dir_train_flowing) @@ -257,14 +259,15 @@ def run(_config, n_classes, n_epochs, input_height, #model.save(dir_output+'/'+'model'+'.h5') elif task=='classification': configuration() - model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining) + model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining) opt_adam = Adam(learning_rate=0.001) model.compile(loss='categorical_crossentropy', optimizer = opt_adam,metrics=['accuracy']) - - testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes) + + list_classes = list(classification_classes_name.values()) + testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes) #print(testY.shape, testY) @@ -280,7 +283,7 @@ def run(_config, n_classes, n_epochs, input_height, for i in range(n_epochs): #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights) - history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights) + history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights) y_pr_class = [] for jj in range(testY.shape[0]): @@ -301,10 +304,6 @@ def run(_config, n_classes, n_epochs, 
input_height, score_best[0]=f1score model.save(os.path.join(dir_output,'model_best')) - - ##best_model=keras.models.clone_model(model) - ##best_model.build() - ##best_model.set_weights(model.get_weights()) if f1score > f1_threshold_classification: weights.append(model.get_weights() ) y_tot=y_tot+y_pr @@ -329,4 +328,9 @@ def run(_config, n_classes, n_epochs, input_height, ##best_model.save('model_taza.h5') model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) + with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + + with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON diff --git a/train/utils.py b/train/utils.py index 0c5a458..3a0375a 100644 --- a/train/utils.py +++ b/train/utils.py @@ -21,14 +21,14 @@ def return_number_of_total_training_data(path_classes): -def generate_data_from_folder_evaluation(path_classes, height, width, n_classes): - sub_classes = os.listdir(path_classes) +def generate_data_from_folder_evaluation(path_classes, height, width, n_classes, list_classes): + #sub_classes = os.listdir(path_classes) #n_classes = len(sub_classes) all_imgs = [] labels = [] - dicts =dict() - indexer= 0 - for sub_c in sub_classes: + #dicts =dict() + #indexer= 0 + for indexer, sub_c in enumerate(list_classes): sub_files = os.listdir(os.path.join(path_classes,sub_c )) sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] #print( os.listdir(os.path.join(path_classes,sub_c )) ) @@ -37,8 +37,8 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes) #print( len(sub_labels) ) labels = labels + sub_labels - dicts[sub_c] = indexer - indexer +=1 + #dicts[sub_c] = indexer + #indexer +=1 categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 
0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ] @@ -64,15 +64,15 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes) return ret_x/255., ret_y -def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes): - sub_classes = os.listdir(path_classes) - n_classes = len(sub_classes) +def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes): + #sub_classes = os.listdir(path_classes) + #n_classes = len(sub_classes) all_imgs = [] labels = [] - dicts =dict() - indexer= 0 - for sub_c in sub_classes: + #dicts =dict() + #indexer= 0 + for indexer, sub_c in enumerate(list_classes): sub_files = os.listdir(os.path.join(path_classes,sub_c )) sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files] #print( os.listdir(os.path.join(path_classes,sub_c )) ) @@ -81,8 +81,8 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n #print( len(sub_labels) ) labels = labels + sub_labels - dicts[sub_c] = indexer - indexer +=1 + #dicts[sub_c] = indexer + #indexer +=1 ids = np.array(range(len(labels))) random.shuffle(ids) From bc2ca7180208a780d2d34710b66bac379a096385 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 7 May 2024 16:24:12 +0200 Subject: [PATCH 043/123] modifications --- train/inference.py | 108 +++++++-------------------------------------- 1 file changed, 17 insertions(+), 91 deletions(-) diff --git a/train/inference.py b/train/inference.py index 6911bea..94e318d 100644 --- a/train/inference.py +++ b/train/inference.py @@ -1,25 +1,16 @@ -#! 
/usr/bin/env python3 - -__version__= '1.0' - -import argparse import sys import os import numpy as np import warnings -import xml.etree.ElementTree as et -import pandas as pd -from tqdm import tqdm -import csv import cv2 import seaborn as sns -import matplotlib.pyplot as plt from tensorflow.keras.models import load_model import tensorflow as tf from tensorflow.keras import backend as K from tensorflow.keras import layers import tensorflow.keras.losses from tensorflow.keras.layers import * +from models import * import click import json from tensorflow.python.keras import backend as tensorflow_backend @@ -37,70 +28,13 @@ __doc__=\ Tool to load model and predict for given image. """ -projection_dim = 64 -patch_size = 1 -num_patches =28*28 -class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - print(tf.shape(images)[1],'images') - print(self.patch_size,'self.patch_size') - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - print(patches.shape,patch_dims,'patch_dims') - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - - -class PatchEncoder(layers.Layer): - def __init__(self, **kwargs): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - 
def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config - - class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth): self.image=image self.patches=patches self.save=save self.model_dir=model self.ground_truth=ground_truth - self.weights_dir=weights_dir self.task=task self.config_params_model=config_params_model @@ -426,16 +360,12 @@ class sbb_predict: pass else: img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) - cv2.imwrite('./test.png',img_seg_overlayed) - ##if self.save!=None: - ##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2) - ##cv2.imwrite(self.save,img) - - ###if self.ground_truth!=None: - ###gt_img=cv2.imread(self.ground_truth) - ###self.IoU(gt_img[:,:,0],res) - ##plt.imshow(res) - ##plt.show() + if self.save: + cv2.imwrite(self.save,img_seg_overlayed) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) @click.command() @click.option( @@ -463,23 +393,19 @@ class sbb_predict: required=True, ) @click.option( - "--ground_truth/--no-ground_truth", - "-gt/-nogt", - is_flag=True, + "--ground_truth", + "-gt", help="ground truth directory if you want to see the iou of prediction.", ) -@click.option( - "--model_weights/--no-model_weights", - "-mw/-nomw", - is_flag=True, - help="previous model weights which are saved.", -) -def main(image, model, patches, save, ground_truth, model_weights): - +def main(image, model, patches, save, ground_truth): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) - task = 'classification' - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, 
model_weights) + task = config_params_model['task'] + if task != 'classification': + if not save: + print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") + sys.exit(1) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth) x.run() if __name__=="__main__": From 241cb907cbb691988866011fdad5af12eb4986ae Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 8 May 2024 14:47:16 +0200 Subject: [PATCH 044/123] Update train.py avoid ensembling if no model weights met the threshold f1 score in the case of classification --- train/train.py | 46 +++++++++++++--------------------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/train/train.py b/train/train.py index 28363d2..78974d3 100644 --- a/train/train.py +++ b/train/train.py @@ -268,36 +268,26 @@ def run(_config, n_classes, n_epochs, input_height, list_classes = list(classification_classes_name.values()) testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes) - - #print(testY.shape, testY) y_tot=np.zeros((testX.shape[0],n_classes)) - indexer=0 score_best=[] score_best.append(0) num_rows = return_number_of_total_training_data(dir_train) - weights=[] for i in range(n_epochs): - #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights) - history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights) + history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights) y_pr_class = [] for jj in range(testY.shape[0]): y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), 
verbose=0) y_pr_ind= np.argmax(y_pr,axis=1) - #print(y_pr_ind, 'y_pr_ind') y_pr_class.append(y_pr_ind) - y_pr_class = np.array(y_pr_class) - #model.save('./models_save/model_'+str(i)+'.h5') - #y_pr_class=np.argmax(y_pr,axis=1) f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro') - print(i,f1score) if f1score>score_best[0]: @@ -306,30 +296,20 @@ def run(_config, n_classes, n_epochs, input_height, if f1score > f1_threshold_classification: weights.append(model.get_weights() ) - y_tot=y_tot+y_pr - indexer+=1 - y_tot=y_tot/float(indexer) - - new_weights=list() - - for weights_list_tuple in zip(*weights): - new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) - - new_weights = [np.array(x) for x in new_weights] - - model_weight_averaged=tf.keras.models.clone_model(model) - - model_weight_averaged.set_weights(new_weights) - - #y_tot_end=np.argmax(y_tot,axis=1) - #print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro')) - - ##best_model.save('model_taza.h5') - model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) - with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON + if len(weights) >= 1: + new_weights=list() + for weights_list_tuple in zip(*weights): + new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) + + new_weights = [np.array(x) for x in new_weights] + model_weight_averaged=tf.keras.models.clone_model(model) + model_weight_averaged.set_weights(new_weights) + + model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) + with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON From 
d277ec4b31dd28a3da3d38e9f9fd37b5c3e17fb2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sun, 12 May 2024 08:32:28 +0200 Subject: [PATCH 045/123] Update utils.py --- train/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train/utils.py b/train/utils.py index 3a0375a..271d977 100644 --- a/train/utils.py +++ b/train/utils.py @@ -9,6 +9,7 @@ from tqdm import tqdm import imutils import math from tensorflow.keras.utils import to_categorical +from PIL import Image, ImageEnhance def return_number_of_total_training_data(path_classes): From d6a057ba702f31c03db0401ab97fcd1a444b89a0 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 16 May 2024 15:03:23 +0200 Subject: [PATCH 046/123] adding page xml to label generator --- train/pagexml2label.py | 1009 ++++++++++++++++++++++++++++++++++++++++ train/requirements.txt | 1 + 2 files changed, 1010 insertions(+) create mode 100644 train/pagexml2label.py diff --git a/train/pagexml2label.py b/train/pagexml2label.py new file mode 100644 index 0000000..715f99f --- /dev/null +++ b/train/pagexml2label.py @@ -0,0 +1,1009 @@ +import click +import sys +import os +import numpy as np +import warnings +import xml.etree.ElementTree as ET +from tqdm import tqdm +import cv2 +from shapely import geometry + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + +__doc__=\ +""" +tool to extract 2d or 3d RGB images from page xml data. In former case output will be 1 +2D image array which each class has filled with a pixel value. In the case of 3D RGB image +each class will be defined with a RGB value and beside images a text file of classes also will be produced. +This classes.txt file is required for dhsegment tool. 
+""" +KERNEL = np.ones((5, 5), np.uint8) + +class pagexml2word: + def __init__(self,dir_in, out_dir,output_type,experiment): + self.dir=dir_in + self.output_dir=out_dir + self.output_type=output_type + self.experiment=experiment + + def get_content_of_dir(self): + """ + Listing all ground truth page xml files. All files are needed to have xml format. + """ + + gt_all=os.listdir(self.dir) + self.gt_list=[file for file in gt_all if file.split('.')[ len(file.split('.'))-1 ]=='xml' ] + + def return_parent_contours(self,contours, hierarchy): + contours_parent = [contours[i] for i in range(len(contours)) if hierarchy[0][i][3] == -1] + return contours_parent + def filter_contours_area_of_image_tables(self,image, contours, hierarchy, max_area, min_area): + found_polygons_early = list() + + jv = 0 + for c in contours: + if len(c) < 3: # A polygon cannot have less than 3 points + continue + + polygon = geometry.Polygon([point[0] for point in c]) + # area = cv2.contourArea(c) + area = polygon.area + ##print(np.prod(thresh.shape[:2])) + # Check that polygon has area greater than minimal area + # print(hierarchy[0][jv][3],hierarchy ) + if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + # print(c[0][0][1]) + found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) + jv += 1 + return found_polygons_early + + def return_contours_of_interested_region(self,region_pre_p, pixel, min_area=0.0002): + + # pixels of images are identified by 5 + if len(region_pre_p.shape) == 3: + cnts_images = (region_pre_p[:, :, 0] == pixel) * 1 + else: + cnts_images = (region_pre_p[:, :] == pixel) * 1 + cnts_images = cnts_images.astype(np.uint8) + cnts_images = np.repeat(cnts_images[:, :, np.newaxis], 3, axis=2) + imgray = cv2.cvtColor(cnts_images, cv2.COLOR_BGR2GRAY) + ret, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours_imgs, hierarchy = cv2.findContours(thresh, 
cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + contours_imgs = self.return_parent_contours(contours_imgs, hierarchy) + contours_imgs = self.filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) + + return contours_imgs + + def get_images_of_ground_truth(self): + """ + Reading the page xml files and write the ground truth images into given output directory. + """ + for index in tqdm(range(len(self.gt_list))): + #try: + tree1 = ET.parse(self.dir+'/'+self.gt_list[index]) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + if self.experiment=='word': + region_tags=np.unique([x for x in alltags if x.endswith('Word')]) + co_word=[] + + for tag in region_tags: + if tag.endswith('}Word') or tag.endswith('}word'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_word.append(np.array(c_t_in)) + + img = np.zeros( (y_len,x_len, 3) ) + if self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_word, color=(1,1,1)) + elif self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_word, color=(255,0,0)) + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + + elif self.experiment=='glyph': + region_tags=np.unique([x for 
x in alltags if x.endswith('Glyph')]) + co_glyph=[] + + for tag in region_tags: + if tag.endswith('}Glyph') or tag.endswith('}glyph'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_glyph.append(np.array(c_t_in)) + + img = np.zeros( (y_len,x_len, 3) ) + if self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_glyph, color=(1,1,1)) + elif self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_glyph, color=(255,0,0)) + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + elif self.experiment=='textline': + region_tags=np.unique([x for x in alltags if x.endswith('TextLine')]) + co_line=[] + + for tag in region_tags: + if tag.endswith('}TextLine') or tag.endswith('}textline'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_line.append(np.array(c_t_in)) + + img = np.zeros( (y_len,x_len, 3) ) + if self.output_type == '2d': + 
img_poly=cv2.fillPoly(img, pts =co_line, color=(1,1,1)) + elif self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_line, color=(255,0,0)) + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + elif self.experiment=='layout_for_main_regions': + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + #print(region_tags) + co_text=[] + co_sep=[] + co_img=[] + #co_graphic=[] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_text.append(np.array(c_t_in)) + + elif tag.endswith('}ImageRegion') or tag.endswith('}GraphicRegion') or tag.endswith('}imageregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + #print('sth') + for nn in 
root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + img = np.zeros( (y_len,x_len,3) ) + + if self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_text, color=(255,0,0)) + img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) + img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) + ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) + elif self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_text, color=(1,1,1)) + img_poly=cv2.fillPoly(img, pts =co_img, color=(2,2,2)) + img_poly=cv2.fillPoly(img, pts =co_sep, color=(3,3,3)) + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + elif self.experiment=='textregion': + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + co_textregion=[] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif 
vv.tag!=link+'Point' and sumi>=1: + break + co_textregion.append(np.array(c_t_in)) + + img = np.zeros( (y_len,x_len,3) ) + if self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_textregion, color=(255,0,0)) + elif self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_textregion, color=(1,1,1)) + + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + elif self.experiment=='layout': + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + + co_text_paragraph=[] + co_text_drop=[] + co_text_heading=[] + co_text_header=[] + co_text_marginalia=[] + co_text_catch=[] + co_text_page_number=[] + co_text_signature_mark=[] + co_sep=[] + co_img=[] + co_table=[] + co_graphic=[] + co_graphic_text_annotation=[] + co_graphic_decoration=[] + co_noise=[] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + c_t_in_drop=[] + c_t_in_paragraph=[] + c_t_in_heading=[] + c_t_in_header=[] + c_t_in_page_number=[] + c_t_in_signature_mark=[] + c_t_in_catch=[] + c_t_in_marginalia=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + #print('birda1') + p_h=vv.attrib['points'].split(' ') + + + + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + #if nn.attrib['type']=='paragraph': + + c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + elif "type" in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': + + c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + 
#print(c_t_in_paragraph) + elif "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + elif "type" in nn.attrib and nn.attrib['type']=='page-number': + + c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #print(c_t_in_paragraph) + + elif "type" in nn.attrib and nn.attrib['type']=='marginalia': + + c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #print(c_t_in_paragraph) + else: + + c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #print(c_t_in_paragraph) + + break + else: + pass + + + if vv.tag==link+'Point': + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + #if nn.attrib['type']=='paragraph': + + c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + + elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': + + c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + #print(c_t_in_paragraph) + sumi+=1 + elif "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + + elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + + elif "type" in nn.attrib and nn.attrib['type']=='page-number': + + c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , 
int(np.float(vv.attrib['y'])) ]) + #print(c_t_in_paragraph) + sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='marginalia': + + c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + #print(c_t_in_paragraph) + sumi+=1 + + else: + c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + #print(c_t_in_paragraph) + sumi+=1 + + #c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + + if len(c_t_in_drop)>0: + co_text_drop.append(np.array(c_t_in_drop)) + if len(c_t_in_paragraph)>0: + co_text_paragraph.append(np.array(c_t_in_paragraph)) + if len(c_t_in_heading)>0: + co_text_heading.append(np.array(c_t_in_heading)) + + if len(c_t_in_header)>0: + co_text_header.append(np.array(c_t_in_header)) + if len(c_t_in_page_number)>0: + co_text_page_number.append(np.array(c_t_in_page_number)) + if len(c_t_in_catch)>0: + co_text_catch.append(np.array(c_t_in_catch)) + + if len(c_t_in_signature_mark)>0: + co_text_signature_mark.append(np.array(c_t_in_signature_mark)) + + if len(c_t_in_marginalia)>0: + co_text_marginalia.append(np.array(c_t_in_marginalia)) + + + elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + c_t_in_text_annotation=[] + c_t_in_decoration=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + #c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + #if nn.attrib['type']=='paragraph': + + c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + elif "type" in nn.attrib and nn.attrib['type']=='decoration': + + 
c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + #print(c_t_in_paragraph) + else: + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + + break + else: + pass + + + if vv.tag==link+'Point': + + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + #if nn.attrib['type']=='paragraph': + + c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif "type" in nn.attrib and nn.attrib['type']=='decoration': + + c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + #print(c_t_in_paragraph) + sumi+=1 + else: + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if len(c_t_in_text_annotation)>0: + co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) + if len(c_t_in_decoration)>0: + co_graphic_decoration.append(np.array(c_t_in_decoration)) + if len(c_t_in)>0: + co_graphic.append(np.array(c_t_in)) + + + + elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + 
p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + + + img = np.zeros( (y_len,x_len,3) ) + + if self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(255,0,0)) + + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(255,125,0)) + img_poly=cv2.fillPoly(img, pts =co_text_header, color=(255,0,125)) + img_poly=cv2.fillPoly(img, pts 
=co_text_catch, color=(125,255,125)) + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0)) + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(0,125,255)) + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(0,125,0)) + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(125,125,125)) + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(0,125,255)) + + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125)) + img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) + img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) + img_poly=cv2.fillPoly(img, pts =co_table, color=(0,255,255)) + img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) + img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255)) + elif self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1)) + + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2)) + img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2)) + img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(3,3,3)) + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(4,4,4)) + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(5,5,5)) + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(6,6,6)) + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(7,7,7)) + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(8,8,8)) + + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(9,9,9)) + img_poly=cv2.fillPoly(img, pts =co_img, color=(10,10,10)) + img_poly=cv2.fillPoly(img, pts =co_sep, color=(11,11,11)) + img_poly=cv2.fillPoly(img, pts =co_table, color=(12,12,12)) + img_poly=cv2.fillPoly(img, pts =co_graphic, color=(13,13,14)) + img_poly=cv2.fillPoly(img, pts =co_noise, color=(15,15,15)) + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + 
cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + + elif self.experiment=='layout_for_main_regions_new_concept': + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + #print(region_tags) + co_text=[] + co_sep=[] + co_img=[] + co_drop = [] + co_graphic=[] + co_table = [] + + for tag in region_tags: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + c_t_in_drop = [] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + else: + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + if len(c_t_in)>0: + co_text.append(np.array(c_t_in)) + if len(c_t_in_drop)>0: + co_drop.append(np.array(c_t_in_drop)) + + elif tag.endswith('}ImageRegion') or tag.endswith('}GraphicRegion') or tag.endswith('}imageregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + 
sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + img_boundary = np.zeros( (y_len,x_len) ) + + + co_text_eroded = [] + for con in co_text: + #try: + img_boundary_in = np.zeros( (y_len,x_len) ) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + #print('bidiahhhhaaa') + + + + #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=2) + + pixel = 1 + min_size = 0 + con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) + + try: + co_text_eroded.append(con_eroded[0]) + except: + 
co_text_eroded.append(con) + + img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=4) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) + + boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + + + ###co_table_eroded = [] + ###for con in co_table: + ####try: + ###img_boundary_in = np.zeros( (y_len,x_len) ) + ###img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + ####print('bidiahhhhaaa') + + + + #####img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + ###img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=2) + + ###pixel = 1 + ###min_size = 0 + ###con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) + + ###try: + ###co_table_eroded.append(con_eroded[0]) + ###except: + ###co_table_eroded.append(con) + + ###img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=4) + + ###boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + ###img_boundary[:,:][boundary[:,:]==1] =1 + #except: + #pass + + #for con in co_img: + #img_boundary_in = np.zeros( (y_len,x_len) ) + #img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) + + #boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + #img_boundary[:,:][boundary[:,:]==1] =1 + + + #for con in co_sep: + + #img_boundary_in = np.zeros( (y_len,x_len) ) + #img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) + + #boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + for con in co_drop: + img_boundary_in = np.zeros( (y_len,x_len) ) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + 
img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) + + boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + + + img = np.zeros( (y_len,x_len,3) ) + + if self.output_type == '2d': + img_poly=cv2.fillPoly(img, pts =co_img, color=(2,2,2)) + + img_poly=cv2.fillPoly(img, pts =co_text_eroded, color=(1,1,1)) + ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(4,4,4)) + ###img_poly=cv2.fillPoly(img, pts =co_table, color=(1,1,1)) + + img_poly=cv2.fillPoly(img, pts =co_drop, color=(1,1,1)) + img_poly[:,:][img_boundary[:,:]==1] = 4 + img_poly=cv2.fillPoly(img, pts =co_sep, color=(3,3,3)) + elif self.output_type == '3d': + img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) + img_poly=cv2.fillPoly(img, pts =co_text_eroded, color=(255,0,0)) + img_poly=cv2.fillPoly(img, pts =co_drop, color=(0,125,255)) + + img_poly[:,:,0][img_boundary[:,:]==1]=255 + img_poly[:,:,1][img_boundary[:,:]==1]=125 + img_poly[:,:,2][img_boundary[:,:]==1]=125 + + img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) + ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) + + #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png') + try: + #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png') + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + + + #except: + #pass + def run(self): + self.get_content_of_dir() + self.get_images_of_ground_truth() + + +@click.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out", + "-do", + help="directory where ground truth images would be written", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--type_output", + "-to", + help="this 
defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 2D image array is 3d but only information of one channel would be enough since all channels have the same values.", +) +@click.option( + "--experiment", + "-exp", + help="experiment of ineterst. Word , textline , glyph and textregion are desired options.", +) + +def main(dir_xml,dir_out,type_output,experiment): + x=pagexml2word(dir_xml,dir_out,type_output,experiment) + x.run() +if __name__=="__main__": + main() + + + diff --git a/train/requirements.txt b/train/requirements.txt index 3e56438..efee9df 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -7,3 +7,4 @@ imutils numpy scipy scikit-learn +shapely From faeac997e15c3dd824a029e8e798fc3e7a262a8c Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 17 May 2024 09:10:13 +0200 Subject: [PATCH 047/123] page to label enable textline new concept --- train/pagexml2label.py | 73 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/train/pagexml2label.py b/train/pagexml2label.py index 715f99f..b094e9b 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -217,6 +217,79 @@ class pagexml2word: except: cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + elif self.experiment == 'textline_new_concept': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + co_line = [] + + for tag in region_tags: + if tag.endswith('}TextLine') or tag.endswith('}textline'): + # print('sth') + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + 
c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) + sumi += 1 + # print(vv.tag,'in') + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_line.append(np.array(c_t_in)) + + img_boundary = np.zeros((y_len, x_len)) + co_textline_eroded = [] + for con in co_line: + # try: + img_boundary_in = np.zeros((y_len, x_len)) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + # print('bidiahhhhaaa') + + # img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + img_boundary_in = cv2.erode(img_boundary_in[:, :], KERNEL, iterations=1) + + pixel = 1 + min_size = 0 + con_eroded = self.return_contours_of_interested_region(img_boundary_in, pixel, min_size) + + try: + co_textline_eroded.append(con_eroded[0]) + except: + co_textline_eroded.append(con) + + img_boundary_in_dilated = cv2.dilate(img_boundary_in[:, :], KERNEL, iterations=3) + # img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) + + boundary = img_boundary_in_dilated[:, :] - img_boundary_in[:, :] + + img_boundary[:, :][boundary[:, :] == 1] = 1 + + img = np.zeros((y_len, x_len, 3)) + if self.output_type == '2d': + img_poly = cv2.fillPoly(img, pts=co_textline_eroded, color=(1, 1, 1)) + img_poly[:, :][img_boundary[:, :] == 1] = 2 + elif self.output_type == '3d': + img_poly = cv2.fillPoly(img, pts=co_textline_eroded, color=(255, 0, 0)) + img_poly[:, :, 0][img_boundary[:, :] == 1] = 255 + img_poly[:, :, 1][img_boundary[:, :] == 1] = 125 + img_poly[:, :, 2][img_boundary[:, :] == 1] = 125 + + try: + cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('-')[1].split('.')[0] + '.png', + img_poly) + except: + cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('.')[0] + '.png', img_poly) + elif self.experiment=='layout_for_main_regions': region_tags=np.unique([x for x in alltags if x.endswith('Region')]) #print(region_tags) From b2085a1d01ec6a501a6f0752f492ab71f3015723 Mon Sep 17 00:00:00 2001 From: 
vahidrezanezhad Date: Fri, 17 May 2024 09:08:25 +0200 Subject: [PATCH 048/123] update requirements --- train/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/train/requirements.txt b/train/requirements.txt index efee9df..d8f9003 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -8,3 +8,4 @@ numpy scipy scikit-learn shapely +click From f1c2913c0394dbb64a5464afc183d3600a222f6b Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 22 May 2024 12:38:24 +0200 Subject: [PATCH 049/123] page2label with a dynamic layout --- train/custom_config_page2label.json | 6 + train/pagexml2label.py | 490 +++++++++++++++++++++++++++- 2 files changed, 479 insertions(+), 17 deletions(-) create mode 100644 train/custom_config_page2label.json diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json new file mode 100644 index 0000000..75c4b96 --- /dev/null +++ b/train/custom_config_page2label.json @@ -0,0 +1,6 @@ +{ +"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginal":4 }, +"imageregion":5, +"separatorregion":6, +"graphicregions" :{"handwritten-annotation":7, "decoration": 8, "signature": 9, "stamp": 10} +} diff --git a/train/pagexml2label.py b/train/pagexml2label.py index b094e9b..6907e84 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -7,6 +7,7 @@ import xml.etree.ElementTree as ET from tqdm import tqdm import cv2 from shapely import geometry +import json with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -21,11 +22,12 @@ This classes.txt file is required for dhsegment tool. 
KERNEL = np.ones((5, 5), np.uint8) class pagexml2word: - def __init__(self,dir_in, out_dir,output_type,experiment): + def __init__(self,dir_in, out_dir,output_type,experiment,layout_config): self.dir=dir_in self.output_dir=out_dir self.output_type=output_type self.experiment=experiment + self.layout_config=layout_config def get_content_of_dir(self): """ @@ -77,7 +79,7 @@ class pagexml2word: return contours_imgs - def get_images_of_ground_truth(self): + def get_images_of_ground_truth(self, config_params): """ Reading the page xml files and write the ground truth images into given output directory. """ @@ -93,6 +95,445 @@ class pagexml2word: for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) + + if self.layout_config: + keys = list(config_params.keys()) + #values = config_params.values() + + if 'textregions' in keys: + types_text_dict = config_params['textregions'] + types_text = list(types_text_dict.keys()) + types_text_label = list(types_text_dict.values()) + if 'graphicregions' in keys: + types_graphic_dict = config_params['graphicregions'] + types_graphic = list(types_graphic_dict.keys()) + types_graphic_label = list(types_graphic_dict.values()) + + + types_text_label_rgb = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (0,125,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,255), (0,255,125)] + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + + co_text_paragraph=[] + co_text_drop=[] + co_text_heading=[] + co_text_header=[] + co_text_marginalia=[] + co_text_catch=[] + co_text_page_number=[] + co_text_signature_mark=[] + co_sep=[] + co_img=[] + co_table=[] + co_graphic_signature=[] + co_graphic_text_annotation=[] + co_graphic_decoration=[] + co_graphic_stamp=[] + co_noise=[] + + for tag in region_tags: + if 'textregions' in keys: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + 
for nn in root1.iter(tag): + c_t_in_drop=[] + c_t_in_paragraph=[] + c_t_in_heading=[] + c_t_in_header=[] + c_t_in_page_number=[] + c_t_in_signature_mark=[] + c_t_in_catch=[] + c_t_in_marginalia=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + #print('birda1') + p_h=vv.attrib['points'].split(' ') + + if "drop-capital" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "heading" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "signature-mark" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='signature-mark': + c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "header" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "catch-word" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "page-number" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='page-number': + c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "marginalia" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='marginalia': + c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "paragraph" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='paragraph': + c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + 
break + else: + pass + + + if vv.tag==link+'Point': + if "drop-capital" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "heading" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "signature-mark" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='signature-mark': + c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "header" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "catch-word" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "page-number" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='page-number': + c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "marginalia" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='marginalia': + c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "paragraph" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='paragraph': + c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + + elif vv.tag!=link+'Point' and sumi>=1: + break + + if len(c_t_in_drop)>0: + co_text_drop.append(np.array(c_t_in_drop)) + if len(c_t_in_paragraph)>0: + co_text_paragraph.append(np.array(c_t_in_paragraph)) + if len(c_t_in_heading)>0: + co_text_heading.append(np.array(c_t_in_heading)) + + if len(c_t_in_header)>0: + 
co_text_header.append(np.array(c_t_in_header)) + if len(c_t_in_page_number)>0: + co_text_page_number.append(np.array(c_t_in_page_number)) + if len(c_t_in_catch)>0: + co_text_catch.append(np.array(c_t_in_catch)) + + if len(c_t_in_signature_mark)>0: + co_text_signature_mark.append(np.array(c_t_in_signature_mark)) + + if len(c_t_in_marginalia)>0: + co_text_marginalia.append(np.array(c_t_in_marginalia)) + + + if 'graphicregions' in keys: + if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in_stamp=[] + c_t_in_text_annotation=[] + c_t_in_decoration=[] + c_t_in_signature=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + if "handwritten-annotation" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "decoration" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='decoration': + c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "stamp" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='stamp': + c_t_in_stamp.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "signature" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='signature': + c_t_in_signature.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + + break + else: + pass + + + if vv.tag==link+'Point': + if "handwritten-annotation" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "decoration" in types_graphic: + if "type" in nn.attrib 
and nn.attrib['type']=='decoration': + c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "stamp" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='stamp': + c_t_in_stamp.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "signature" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='signature': + c_t_in_signature.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if len(c_t_in_text_annotation)>0: + co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) + if len(c_t_in_decoration)>0: + co_graphic_decoration.append(np.array(c_t_in_decoration)) + if len(c_t_in_stamp)>0: + co_graphic_stamp.append(np.array(c_t_in_stamp)) + if len(c_t_in_signature)>0: + co_graphic_signature.append(np.array(c_t_in_signature)) + + if 'imageregion' in keys: + if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + + if 'separatorregion' in keys: + if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ 
int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + if 'tableregion' in keys: + if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + if 'noiseregion' in keys: + if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + + img = np.zeros( (y_len,x_len,3) ) + + if self.output_type == '3d': + + if 'graphicregions' in keys: + if "handwritten-annotation" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=types_text_label_rgb[ config_params['graphicregions']['handwritten-annotation']]) + if "signature" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=types_text_label_rgb[ config_params['graphicregions']['signature']]) + if "decoration" in 
types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=types_text_label_rgb[ config_params['graphicregions']['decoration']]) + if "stamp" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=types_text_label_rgb[ config_params['graphicregions']['stamp']]) + + if 'imageregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_img, color=types_text_label_rgb[ config_params['imageregion']]) + if 'separatorregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_sep, color=types_text_label_rgb[ config_params['separatorregion']]) + if 'tableregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_table, color=types_text_label_rgb[ config_params['tableregion']]) + if 'noiseregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_noise, color=types_text_label_rgb[ config_params['noiseregion']]) + + if 'textregions' in keys: + if "paragraph" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=types_text_label_rgb[ config_params['textregions']['paragraph']]) + if "heading" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=types_text_label_rgb[ config_params['textregions']['heading']]) + if "header" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_header, color=types_text_label_rgb[ config_params['textregions']['header']]) + if "catch-word" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_catch, color=types_text_label_rgb[ config_params['textregions']['catch-word']]) + if "signature-mark" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=types_text_label_rgb[ config_params['textregions']['signature-mark']]) + if "page-number" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=types_text_label_rgb[ config_params['textregions']['page-number']]) + if "marginalia" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=types_text_label_rgb[ config_params['textregions']['marginalia']]) + if "drop-capital" in 
types_text: + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=types_text_label_rgb[ config_params['textregions']['drop-capital']]) + + elif self.output_type == '2d': + if 'graphicregions' in keys: + if "handwritten-annotation" in types_graphic: + color_label = config_params['graphicregions']['handwritten-annotation'] + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(color_label,color_label,color_label)) + if "signature" in types_graphic: + color_label = config_params['graphicregions']['signature'] + img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=(color_label,color_label,color_label)) + if "decoration" in types_graphic: + color_label = config_params['graphicregions']['decoration'] + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(color_label,color_label,color_label)) + if "stamp" in types_graphic: + color_label = config_params['graphicregions']['stamp'] + img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=(color_label,color_label,color_label)) + + if 'imageregion' in keys: + color_label = config_params['imageregion'] + img_poly=cv2.fillPoly(img, pts =co_img, color=(color_label,color_label,color_label)) + if 'separatorregion' in keys: + color_label = config_params['separatorregion'] + img_poly=cv2.fillPoly(img, pts =co_sep, color=(color_label,color_label,color_label)) + if 'tableregion' in keys: + color_label = config_params['tableregion'] + img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label)) + if 'noiseregion' in keys: + color_label = config_params['noiseregion'] + img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) + + if 'textregions' in keys: + if "paragraph" in types_text: + color_label = config_params['textregions']['paragraph'] + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(color_label,color_label,color_label)) + if "heading" in types_text: + color_label = config_params['textregions']['heading'] + img_poly=cv2.fillPoly(img, 
pts =co_text_heading, color=(color_label,color_label,color_label)) + if "header" in types_text: + color_label = config_params['textregions']['header'] + img_poly=cv2.fillPoly(img, pts =co_text_header, color=(color_label,color_label,color_label)) + if "catch-word" in types_text: + color_label = config_params['textregions']['catch-word'] + img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(color_label,color_label,color_label)) + if "signature-mark" in types_text: + color_label = config_params['textregions']['signature-mark'] + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(color_label,color_label,color_label)) + if "page-number" in types_text: + color_label = config_params['textregions']['page-number'] + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(color_label,color_label,color_label)) + if "marginalia" in types_text: + color_label = config_params['textregions']['marginalia'] + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(color_label,color_label,color_label)) + if "drop-capital" in types_text: + color_label = config_params['textregions']['drop-capital'] + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(color_label,color_label,color_label)) + + + + + try: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + except: + cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) + + + #print(values[0]) if self.experiment=='word': region_tags=np.unique([x for x in alltags if x.endswith('Word')]) co_word=[] @@ -302,6 +743,7 @@ class pagexml2word: if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): #print('sth') for nn in root1.iter(tag): + print(nn.attrib['type']) c_t_in=[] sumi=0 for vv in nn.iter(): @@ -373,20 +815,19 @@ class pagexml2word: elif vv.tag!=link+'Point' and sumi>=1: break co_sep.append(np.array(c_t_in)) - - - img = np.zeros( (y_len,x_len,3) ) + img_poly = np.zeros( (y_len,x_len,3) ) + if self.output_type == '3d': - 
img_poly=cv2.fillPoly(img, pts =co_text, color=(255,0,0)) - img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) - img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) + img_poly=cv2.fillPoly(img_poly, pts =co_text, color=(255,0,0)) + img_poly=cv2.fillPoly(img_poly, pts =co_img, color=(0,255,0)) + img_poly=cv2.fillPoly(img_poly, pts =co_sep, color=(0,0,255)) ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) elif self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts =co_text, color=(1,1,1)) - img_poly=cv2.fillPoly(img, pts =co_img, color=(2,2,2)) - img_poly=cv2.fillPoly(img, pts =co_sep, color=(3,3,3)) + img_poly=cv2.fillPoly(img_poly, pts =co_text, color=(1,1,1)) + img_poly=cv2.fillPoly(img_poly, pts =co_img, color=(2,2,2)) + img_poly=cv2.fillPoly(img_poly, pts =co_sep, color=(3,3,3)) try: cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) @@ -752,7 +1193,7 @@ class pagexml2word: img = np.zeros( (y_len,x_len,3) ) - + if self.output_type == '3d': img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(255,0,0)) @@ -1043,9 +1484,9 @@ class pagexml2word: #except: #pass - def run(self): + def run(self,config_params): self.get_content_of_dir() - self.get_images_of_ground_truth() + self.get_images_of_ground_truth(config_params) @click.command() @@ -1061,6 +1502,14 @@ class pagexml2word: help="directory where ground truth images would be written", type=click.Path(exists=True, file_okay=False), ) + +@click.option( + "--layout_config", + "-lc", + help="experiment of ineterst. Word , textline , glyph and textregion are desired options.", + type=click.Path(exists=True, dir_okay=False), +) + @click.option( "--type_output", "-to", @@ -1072,9 +1521,16 @@ class pagexml2word: help="experiment of ineterst. 
Word , textline , glyph and textregion are desired options.", ) -def main(dir_xml,dir_out,type_output,experiment): - x=pagexml2word(dir_xml,dir_out,type_output,experiment) - x.run() + +def main(dir_xml,dir_out,type_output,experiment,layout_config): + if layout_config: + with open(layout_config) as f: + config_params = json.load(f) + else: + print("passed") + config_params = None + x=pagexml2word(dir_xml,dir_out,type_output,experiment, layout_config) + x.run(config_params) if __name__=="__main__": main() From 47c6bf6b97db0e8ea9eb3e796cf9261ddaa2e4db Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 23 May 2024 11:14:14 +0200 Subject: [PATCH 050/123] dynamic layout decorated with artificial class on text elements boundry --- train/custom_config_page2label.json | 6 +- train/pagexml2label.py | 117 +++++++++++++++++++++++----- 2 files changed, 103 insertions(+), 20 deletions(-) diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json index 75c4b96..85b5d7e 100644 --- a/train/custom_config_page2label.json +++ b/train/custom_config_page2label.json @@ -1,6 +1,8 @@ { -"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginal":4 }, +"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginalia":4 ,"page-number":1 , "catch-word":1 }, "imageregion":5, "separatorregion":6, -"graphicregions" :{"handwritten-annotation":7, "decoration": 8, "signature": 9, "stamp": 10} +"graphicregions" :{"handwritten-annotation":7, "decoration": 8, "signature": 9, "stamp": 10}, +"artificial_class_on_boundry": ["paragraph","header", "heading", "marginalia", "page-number", "catch-word", "drop-capital"], +"artificial_class_label":11 } diff --git a/train/pagexml2label.py b/train/pagexml2label.py index 6907e84..5311c24 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -78,7 +78,37 @@ class pagexml2word: contours_imgs = self.filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, 
max_area=1, min_area=min_area) return contours_imgs + def update_region_contours(self, co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): + co_text_eroded = [] + for con in co_text: + #try: + img_boundary_in = np.zeros( (y_len,x_len) ) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + #print('bidiahhhhaaa') + + + + #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + if erosion_rate > 0: + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_rate) + + pixel = 1 + min_size = 0 + con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) + + try: + co_text_eroded.append(con_eroded[0]) + except: + co_text_eroded.append(con) + + img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) + + boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + return co_text_eroded, img_boundary def get_images_of_ground_truth(self, config_params): """ Reading the page xml files and write the ground truth images into given output directory. 
@@ -98,6 +128,10 @@ class pagexml2word: if self.layout_config: keys = list(config_params.keys()) + if "artificial_class_on_boundry" in keys: + elements_with_artificial_class = list(config_params['artificial_class_on_boundry']) + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] #values = config_params.values() if 'textregions' in keys: @@ -110,7 +144,7 @@ class pagexml2word: types_graphic_label = list(types_graphic_dict.values()) - types_text_label_rgb = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (0,125,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,255), (0,255,125)] + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125)] region_tags=np.unique([x for x in alltags if x.endswith('Region')]) @@ -429,46 +463,90 @@ class pagexml2word: break co_noise.append(np.array(c_t_in)) + if "artificial_class_on_boundry" in keys: + img_boundary = np.zeros( (y_len,x_len) ) + if "paragraph" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_paragraph, img_boundary = self.update_region_contours(co_text_paragraph, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "drop-capital" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + co_text_drop, img_boundary = self.update_region_contours(co_text_drop, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "catch-word" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + co_text_catch, img_boundary = self.update_region_contours(co_text_catch, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "page-number" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + 
co_text_page_number, img_boundary = self.update_region_contours(co_text_page_number, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "header" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_header, img_boundary = self.update_region_contours(co_text_header, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "heading" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_heading, img_boundary = self.update_region_contours(co_text_heading, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "signature-mark" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_signature_mark, img_boundary = self.update_region_contours(co_text_signature_mark, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "marginalia" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_marginalia, img_boundary = self.update_region_contours(co_text_marginalia, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + + img = np.zeros( (y_len,x_len,3) ) if self.output_type == '3d': if 'graphicregions' in keys: if "handwritten-annotation" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=types_text_label_rgb[ config_params['graphicregions']['handwritten-annotation']]) + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=labels_rgb_color[ config_params['graphicregions']['handwritten-annotation']]) if "signature" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=types_text_label_rgb[ config_params['graphicregions']['signature']]) + img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=labels_rgb_color[ config_params['graphicregions']['signature']]) if "decoration" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=types_text_label_rgb[ config_params['graphicregions']['decoration']]) + img_poly=cv2.fillPoly(img, 
pts =co_graphic_decoration, color=labels_rgb_color[ config_params['graphicregions']['decoration']]) if "stamp" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=types_text_label_rgb[ config_params['graphicregions']['stamp']]) + img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=labels_rgb_color[ config_params['graphicregions']['stamp']]) if 'imageregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_img, color=types_text_label_rgb[ config_params['imageregion']]) + img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) if 'separatorregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_sep, color=types_text_label_rgb[ config_params['separatorregion']]) + img_poly=cv2.fillPoly(img, pts =co_sep, color=labels_rgb_color[ config_params['separatorregion']]) if 'tableregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_table, color=types_text_label_rgb[ config_params['tableregion']]) + img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) if 'noiseregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_noise, color=types_text_label_rgb[ config_params['noiseregion']]) + img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) if 'textregions' in keys: if "paragraph" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=types_text_label_rgb[ config_params['textregions']['paragraph']]) + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=labels_rgb_color[ config_params['textregions']['paragraph']]) if "heading" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=types_text_label_rgb[ config_params['textregions']['heading']]) + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=labels_rgb_color[ config_params['textregions']['heading']]) if "header" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_header, color=types_text_label_rgb[ 
config_params['textregions']['header']]) + img_poly=cv2.fillPoly(img, pts =co_text_header, color=labels_rgb_color[ config_params['textregions']['header']]) if "catch-word" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=types_text_label_rgb[ config_params['textregions']['catch-word']]) + img_poly=cv2.fillPoly(img, pts =co_text_catch, color=labels_rgb_color[ config_params['textregions']['catch-word']]) if "signature-mark" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=types_text_label_rgb[ config_params['textregions']['signature-mark']]) + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=labels_rgb_color[ config_params['textregions']['signature-mark']]) if "page-number" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=types_text_label_rgb[ config_params['textregions']['page-number']]) + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=labels_rgb_color[ config_params['textregions']['page-number']]) if "marginalia" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=types_text_label_rgb[ config_params['textregions']['marginalia']]) + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=labels_rgb_color[ config_params['textregions']['marginalia']]) if "drop-capital" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=types_text_label_rgb[ config_params['textregions']['drop-capital']]) + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=labels_rgb_color[ config_params['textregions']['drop-capital']]) + + if "artificial_class_on_boundry" in keys: + img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] + img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] + img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + + elif self.output_type == '2d': if 'graphicregions' in keys: @@ -523,6 +601,9 @@ class pagexml2word: if "drop-capital" in types_text: color_label = 
config_params['textregions']['drop-capital'] img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(color_label,color_label,color_label)) + + if "artificial_class_on_boundry" in keys: + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label @@ -1506,7 +1587,7 @@ class pagexml2word: @click.option( "--layout_config", "-lc", - help="experiment of ineterst. Word , textline , glyph and textregion are desired options.", + help="config file of prefered layout.", type=click.Path(exists=True, dir_okay=False), ) From 348d323c7cd98c53bfdbde37c517c5217db14f11 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 23 May 2024 15:43:31 +0200 Subject: [PATCH 051/123] missing text types are added --- train/custom_config_page2label.json | 12 ++++---- train/pagexml2label.py | 48 ++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json index 85b5d7e..254f4df 100644 --- a/train/custom_config_page2label.json +++ b/train/custom_config_page2label.json @@ -1,8 +1,8 @@ { -"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginalia":4 ,"page-number":1 , "catch-word":1 }, -"imageregion":5, -"separatorregion":6, -"graphicregions" :{"handwritten-annotation":7, "decoration": 8, "signature": 9, "stamp": 10}, -"artificial_class_on_boundry": ["paragraph","header", "heading", "marginalia", "page-number", "catch-word", "drop-capital"], -"artificial_class_label":11 +"textregions":{"paragraph":1, "heading": 1, "header":1,"drop-capital": 1, "marginalia":1 ,"page-number":1 , "catch-word":1 ,"footnote": 1, "footnote-continued": 1}, +"imageregion":2, +"separatorregion":3, +"graphicregions" :{"handwritten-annotation":2, "decoration": 2, "signature": 2, "stamp": 2}, +"artificial_class_on_boundry": ["paragraph","header", "heading", "marginalia", "page-number", "catch-word", "drop-capital","footnote", "footnote-continued"], +"artificial_class_label":4 } diff --git 
a/train/pagexml2label.py b/train/pagexml2label.py index 5311c24..63b7acf 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -113,6 +113,7 @@ class pagexml2word: """ Reading the page xml files and write the ground truth images into given output directory. """ + ## to do: add footnote to text regions for index in tqdm(range(len(self.gt_list))): #try: tree1 = ET.parse(self.dir+'/'+self.gt_list[index]) @@ -144,11 +145,13 @@ class pagexml2word: types_graphic_label = list(types_graphic_dict.values()) - labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125)] + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] region_tags=np.unique([x for x in alltags if x.endswith('Region')]) co_text_paragraph=[] + co_text_footnote=[] + co_text_footnote_con=[] co_text_drop=[] co_text_heading=[] co_text_header=[] @@ -177,6 +180,8 @@ class pagexml2word: c_t_in_signature_mark=[] c_t_in_catch=[] c_t_in_marginalia=[] + c_t_in_footnote=[] + c_t_in_footnote_con=[] sumi=0 for vv in nn.iter(): # check the format of coords @@ -190,6 +195,14 @@ class pagexml2word: if "drop-capital" in types_text: if "type" in nn.attrib and nn.attrib['type']=='drop-capital': c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "footnote" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote': + c_t_in_footnote.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "footnote-continued" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': + c_t_in_footnote_con.append( np.array( [ 
[ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) if "heading" in types_text: if "type" in nn.attrib and nn.attrib['type']=='heading': @@ -231,6 +244,16 @@ class pagexml2word: c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) sumi+=1 + if "footnote" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote': + c_t_in_footnote.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "footnote-continued" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': + c_t_in_footnote_con.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + if "heading" in types_text: if "type" in nn.attrib and nn.attrib['type']=='heading': c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) @@ -272,6 +295,10 @@ class pagexml2word: if len(c_t_in_drop)>0: co_text_drop.append(np.array(c_t_in_drop)) + if len(c_t_in_footnote_con)>0: + co_text_footnote_con.append(np.array(c_t_in_footnote_con)) + if len(c_t_in_footnote)>0: + co_text_footnote.append(np.array(c_t_in_footnote)) if len(c_t_in_paragraph)>0: co_text_paragraph.append(np.array(c_t_in_paragraph)) if len(c_t_in_heading)>0: @@ -497,6 +524,15 @@ class pagexml2word: erosion_rate = 2 dilation_rate = 4 co_text_marginalia, img_boundary = self.update_region_contours(co_text_marginalia, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "footnote" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_footnote, img_boundary = self.update_region_contours(co_text_footnote, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "footnote-continued" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_footnote_con, img_boundary = self.update_region_contours(co_text_footnote_con, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + img = np.zeros( (y_len,x_len,3) ) @@ 
-525,6 +561,10 @@ class pagexml2word: if 'textregions' in keys: if "paragraph" in types_text: img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=labels_rgb_color[ config_params['textregions']['paragraph']]) + if "footnote" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=labels_rgb_color[ config_params['textregions']['footnote']]) + if "footnote-continued" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=labels_rgb_color[ config_params['textregions']['footnote-continued']]) if "heading" in types_text: img_poly=cv2.fillPoly(img, pts =co_text_heading, color=labels_rgb_color[ config_params['textregions']['heading']]) if "header" in types_text: @@ -580,6 +620,12 @@ class pagexml2word: if "paragraph" in types_text: color_label = config_params['textregions']['paragraph'] img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(color_label,color_label,color_label)) + if "footnote" in types_text: + color_label = config_params['textregions']['footnote'] + img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=(color_label,color_label,color_label)) + if "footnote-continued" in types_text: + color_label = config_params['textregions']['footnote-continued'] + img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=(color_label,color_label,color_label)) if "heading" in types_text: color_label = config_params['textregions']['heading'] img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(color_label,color_label,color_label)) From a83d53c27d09c962c54f441e225c70fbd820900b Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 23 May 2024 17:14:31 +0200 Subject: [PATCH 052/123] use cases like textline, word and glyph are added --- train/custom_config_page2label.json | 11 +- train/pagexml2label.py | 1055 +++------------------------ 2 files changed, 93 insertions(+), 973 deletions(-) diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json index 254f4df..d6320fa 100644 --- 
a/train/custom_config_page2label.json +++ b/train/custom_config_page2label.json @@ -1,8 +1,9 @@ { -"textregions":{"paragraph":1, "heading": 1, "header":1,"drop-capital": 1, "marginalia":1 ,"page-number":1 , "catch-word":1 ,"footnote": 1, "footnote-continued": 1}, -"imageregion":2, -"separatorregion":3, -"graphicregions" :{"handwritten-annotation":2, "decoration": 2, "signature": 2, "stamp": 2}, +"use_case": "layout", +"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginalia":4 ,"page-number":1 , "catch-word":1 ,"footnote": 1, "footnote-continued": 1}, +"imageregion":5, +"separatorregion":6, +"graphicregions" :{"handwritten-annotation":5, "decoration": 5, "signature": 5, "stamp": 5}, "artificial_class_on_boundry": ["paragraph","header", "heading", "marginalia", "page-number", "catch-word", "drop-capital","footnote", "footnote-continued"], -"artificial_class_label":4 +"artificial_class_label":7 } diff --git a/train/pagexml2label.py b/train/pagexml2label.py index 63b7acf..16cda8b 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -21,13 +21,12 @@ This classes.txt file is required for dhsegment tool. 
""" KERNEL = np.ones((5, 5), np.uint8) -class pagexml2word: - def __init__(self,dir_in, out_dir,output_type,experiment,layout_config): +class pagexml2label: + def __init__(self,dir_in, out_dir,output_type,config): self.dir=dir_in self.output_dir=out_dir self.output_type=output_type - self.experiment=experiment - self.layout_config=layout_config + self.config=config def get_content_of_dir(self): """ @@ -127,7 +126,82 @@ class pagexml2word: y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) - if self.layout_config: + if self.config and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph'): + keys = list(config_params.keys()) + if "artificial_class_label" in keys: + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + + textline_rgb_color = (255, 0, 0) + + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + co_use_case = [] + + for tag in region_tags: + if config_params['use_case']=='textline': + tag_endings = ['}TextLine','}textline'] + elif config_params['use_case']=='word': + tag_endings = ['}Word','}word'] + elif config_params['use_case']=='glyph': + tag_endings = ['}Glyph','}glyph'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + 
c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + + + if "artificial_class_label" in keys: + img_boundary = np.zeros((y_len, x_len)) + erosion_rate = 1 + dilation_rate = 3 + co_use_case, img_boundary = self.update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + + + img = np.zeros((y_len, x_len, 3)) + if self.output_type == '2d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + if "artificial_class_label" in keys: + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + elif self.output_type == '3d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) + if "artificial_class_label" in keys: + img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] + img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] + img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + try: + cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('-')[1].split('.')[0] + '.png', + img_poly) + except: + cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('.')[0] + '.png', img_poly) + + + if self.config and config_params['use_case']=='layout': keys = list(config_params.keys()) if "artificial_class_on_boundry" in keys: elements_with_artificial_class = list(config_params['artificial_class_on_boundry']) @@ -139,6 +213,7 @@ class pagexml2word: types_text_dict = config_params['textregions'] types_text = list(types_text_dict.keys()) types_text_label = list(types_text_dict.values()) + print(types_text) if 'graphicregions' in keys: types_graphic_dict = config_params['graphicregions'] types_graphic = list(types_graphic_dict.keys()) @@ -660,957 +735,6 @@ class pagexml2word: cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - #print(values[0]) - if self.experiment=='word': - 
region_tags=np.unique([x for x in alltags if x.endswith('Word')]) - co_word=[] - - for tag in region_tags: - if tag.endswith('}Word') or tag.endswith('}word'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_word.append(np.array(c_t_in)) - - img = np.zeros( (y_len,x_len, 3) ) - if self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts =co_word, color=(1,1,1)) - elif self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_word, color=(255,0,0)) - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - - elif self.experiment=='glyph': - region_tags=np.unique([x for x in alltags if x.endswith('Glyph')]) - co_glyph=[] - - for tag in region_tags: - if tag.endswith('}Glyph') or tag.endswith('}glyph'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_glyph.append(np.array(c_t_in)) - - img = np.zeros( (y_len,x_len, 3) ) - if self.output_type == 
'2d': - img_poly=cv2.fillPoly(img, pts =co_glyph, color=(1,1,1)) - elif self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_glyph, color=(255,0,0)) - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - elif self.experiment=='textline': - region_tags=np.unique([x for x in alltags if x.endswith('TextLine')]) - co_line=[] - - for tag in region_tags: - if tag.endswith('}TextLine') or tag.endswith('}textline'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_line.append(np.array(c_t_in)) - - img = np.zeros( (y_len,x_len, 3) ) - if self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts =co_line, color=(1,1,1)) - elif self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_line, color=(255,0,0)) - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - elif self.experiment == 'textline_new_concept': - region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) - co_line = [] - - for tag in region_tags: - if tag.endswith('}TextLine') or tag.endswith('}textline'): - # print('sth') - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: 
- p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) - break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) - sumi += 1 - # print(vv.tag,'in') - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_line.append(np.array(c_t_in)) - - img_boundary = np.zeros((y_len, x_len)) - co_textline_eroded = [] - for con in co_line: - # try: - img_boundary_in = np.zeros((y_len, x_len)) - img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - # print('bidiahhhhaaa') - - # img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica - img_boundary_in = cv2.erode(img_boundary_in[:, :], KERNEL, iterations=1) - - pixel = 1 - min_size = 0 - con_eroded = self.return_contours_of_interested_region(img_boundary_in, pixel, min_size) - - try: - co_textline_eroded.append(con_eroded[0]) - except: - co_textline_eroded.append(con) - - img_boundary_in_dilated = cv2.dilate(img_boundary_in[:, :], KERNEL, iterations=3) - # img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) - - boundary = img_boundary_in_dilated[:, :] - img_boundary_in[:, :] - - img_boundary[:, :][boundary[:, :] == 1] = 1 - - img = np.zeros((y_len, x_len, 3)) - if self.output_type == '2d': - img_poly = cv2.fillPoly(img, pts=co_textline_eroded, color=(1, 1, 1)) - img_poly[:, :][img_boundary[:, :] == 1] = 2 - elif self.output_type == '3d': - img_poly = cv2.fillPoly(img, pts=co_textline_eroded, color=(255, 0, 0)) - img_poly[:, :, 0][img_boundary[:, :] == 1] = 255 - img_poly[:, :, 1][img_boundary[:, :] == 1] = 125 - img_poly[:, :, 2][img_boundary[:, :] == 1] = 125 - - try: - cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('-')[1].split('.')[0] + '.png', - img_poly) - except: - cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('.')[0] + '.png', img_poly) - - elif 
self.experiment=='layout_for_main_regions': - region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - #print(region_tags) - co_text=[] - co_sep=[] - co_img=[] - #co_graphic=[] - - for tag in region_tags: - if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): - #print('sth') - for nn in root1.iter(tag): - print(nn.attrib['type']) - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_text.append(np.array(c_t_in)) - - elif tag.endswith('}ImageRegion') or tag.endswith('}GraphicRegion') or tag.endswith('}imageregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_img.append(np.array(c_t_in)) - - elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if 
vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_sep.append(np.array(c_t_in)) - - img_poly = np.zeros( (y_len,x_len,3) ) - - - if self.output_type == '3d': - img_poly=cv2.fillPoly(img_poly, pts =co_text, color=(255,0,0)) - img_poly=cv2.fillPoly(img_poly, pts =co_img, color=(0,255,0)) - img_poly=cv2.fillPoly(img_poly, pts =co_sep, color=(0,0,255)) - ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) - elif self.output_type == '2d': - img_poly=cv2.fillPoly(img_poly, pts =co_text, color=(1,1,1)) - img_poly=cv2.fillPoly(img_poly, pts =co_img, color=(2,2,2)) - img_poly=cv2.fillPoly(img_poly, pts =co_sep, color=(3,3,3)) - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - elif self.experiment=='textregion': - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - co_textregion=[] - - for tag in region_tags: - if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_textregion.append(np.array(c_t_in)) - - img = np.zeros( (y_len,x_len,3) ) - if self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_textregion, color=(255,0,0)) - elif self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts 
=co_textregion, color=(1,1,1)) - - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - elif self.experiment=='layout': - region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - - co_text_paragraph=[] - co_text_drop=[] - co_text_heading=[] - co_text_header=[] - co_text_marginalia=[] - co_text_catch=[] - co_text_page_number=[] - co_text_signature_mark=[] - co_sep=[] - co_img=[] - co_table=[] - co_graphic=[] - co_graphic_text_annotation=[] - co_graphic_decoration=[] - co_noise=[] - - for tag in region_tags: - if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): - for nn in root1.iter(tag): - c_t_in_drop=[] - c_t_in_paragraph=[] - c_t_in_heading=[] - c_t_in_header=[] - c_t_in_page_number=[] - c_t_in_signature_mark=[] - c_t_in_catch=[] - c_t_in_marginalia=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - - coords=bool(vv.attrib) - if coords: - #print('birda1') - p_h=vv.attrib['points'].split(' ') - - - - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': - - c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - elif "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': - - c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) - elif "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append( np.array( [ [ 
int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - - c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) - - elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - - c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) - else: - - c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) - - break - else: - pass - - - if vv.tag==link+'Point': - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': - - c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - elif "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - - elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': - - c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 - elif "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - - c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 - - elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - - c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 - - else: - 
c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 - - #c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - - if len(c_t_in_drop)>0: - co_text_drop.append(np.array(c_t_in_drop)) - if len(c_t_in_paragraph)>0: - co_text_paragraph.append(np.array(c_t_in_paragraph)) - if len(c_t_in_heading)>0: - co_text_heading.append(np.array(c_t_in_heading)) - - if len(c_t_in_header)>0: - co_text_header.append(np.array(c_t_in_header)) - if len(c_t_in_page_number)>0: - co_text_page_number.append(np.array(c_t_in_page_number)) - if len(c_t_in_catch)>0: - co_text_catch.append(np.array(c_t_in_catch)) - - if len(c_t_in_signature_mark)>0: - co_text_signature_mark.append(np.array(c_t_in_signature_mark)) - - if len(c_t_in_marginalia)>0: - co_text_marginalia.append(np.array(c_t_in_marginalia)) - - - elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - c_t_in_text_annotation=[] - c_t_in_decoration=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - #c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - - c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - elif "type" in nn.attrib and nn.attrib['type']=='decoration': - - c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) - else: - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - - break - else: - pass - - - if vv.tag==link+'Point': - - 
if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - - c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - elif "type" in nn.attrib and nn.attrib['type']=='decoration': - - c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 - else: - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if len(c_t_in_text_annotation)>0: - co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) - if len(c_t_in_decoration)>0: - co_graphic_decoration.append(np.array(c_t_in_decoration)) - if len(c_t_in)>0: - co_graphic.append(np.array(c_t_in)) - - - - elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_img.append(np.array(c_t_in)) - - elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif 
vv.tag!=link+'Point' and sumi>=1: - break - co_sep.append(np.array(c_t_in)) - - - - elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_table.append(np.array(c_t_in)) - - elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_noise.append(np.array(c_t_in)) - - - img = np.zeros( (y_len,x_len,3) ) - - if self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(255,0,0)) - - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(255,125,0)) - img_poly=cv2.fillPoly(img, pts =co_text_header, color=(255,0,125)) - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(125,255,125)) - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0)) - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(0,125,255)) - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(0,125,0)) - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, 
color=(125,125,125)) - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(0,125,255)) - - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125)) - img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) - img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) - img_poly=cv2.fillPoly(img, pts =co_table, color=(0,255,255)) - img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) - img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255)) - elif self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1)) - - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2)) - img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2)) - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(3,3,3)) - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(4,4,4)) - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(5,5,5)) - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(6,6,6)) - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(7,7,7)) - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(8,8,8)) - - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(9,9,9)) - img_poly=cv2.fillPoly(img, pts =co_img, color=(10,10,10)) - img_poly=cv2.fillPoly(img, pts =co_sep, color=(11,11,11)) - img_poly=cv2.fillPoly(img, pts =co_table, color=(12,12,12)) - img_poly=cv2.fillPoly(img, pts =co_graphic, color=(13,13,14)) - img_poly=cv2.fillPoly(img, pts =co_noise, color=(15,15,15)) - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - - elif self.experiment=='layout_for_main_regions_new_concept': - region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - #print(region_tags) - co_text=[] - co_sep=[] - co_img=[] - co_drop = [] - co_graphic=[] - co_table = [] - - for tag in region_tags: - if 
tag.endswith('}TextRegion') or tag.endswith('}Textregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - c_t_in_drop = [] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - else: - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - else: - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - if len(c_t_in)>0: - co_text.append(np.array(c_t_in)) - if len(c_t_in_drop)>0: - co_drop.append(np.array(c_t_in_drop)) - - elif tag.endswith('}ImageRegion') or tag.endswith('}GraphicRegion') or tag.endswith('}imageregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_img.append(np.array(c_t_in)) - - elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - 
coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_sep.append(np.array(c_t_in)) - - elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_table.append(np.array(c_t_in)) - - img_boundary = np.zeros( (y_len,x_len) ) - - - co_text_eroded = [] - for con in co_text: - #try: - img_boundary_in = np.zeros( (y_len,x_len) ) - img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #print('bidiahhhhaaa') - - - - #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica - img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=2) - - pixel = 1 - min_size = 0 - con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) - - try: - co_text_eroded.append(con_eroded[0]) - except: - co_text_eroded.append(con) - - img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=4) - #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) - - boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - img_boundary[:,:][boundary[:,:]==1] =1 - - - ###co_table_eroded = [] - 
###for con in co_table: - ####try: - ###img_boundary_in = np.zeros( (y_len,x_len) ) - ###img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - ####print('bidiahhhhaaa') - - - - #####img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica - ###img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=2) - - ###pixel = 1 - ###min_size = 0 - ###con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) - - ###try: - ###co_table_eroded.append(con_eroded[0]) - ###except: - ###co_table_eroded.append(con) - - ###img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=4) - - ###boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - ###img_boundary[:,:][boundary[:,:]==1] =1 - #except: - #pass - - #for con in co_img: - #img_boundary_in = np.zeros( (y_len,x_len) ) - #img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) - - #boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - #img_boundary[:,:][boundary[:,:]==1] =1 - - - #for con in co_sep: - - #img_boundary_in = np.zeros( (y_len,x_len) ) - #img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) - - #boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - img_boundary[:,:][boundary[:,:]==1] =1 - for con in co_drop: - img_boundary_in = np.zeros( (y_len,x_len) ) - img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=3) - - boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - img_boundary[:,:][boundary[:,:]==1] =1 - - - img = np.zeros( (y_len,x_len,3) ) - - if self.output_type == '2d': - img_poly=cv2.fillPoly(img, pts =co_img, color=(2,2,2)) - - 
img_poly=cv2.fillPoly(img, pts =co_text_eroded, color=(1,1,1)) - ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(4,4,4)) - ###img_poly=cv2.fillPoly(img, pts =co_table, color=(1,1,1)) - - img_poly=cv2.fillPoly(img, pts =co_drop, color=(1,1,1)) - img_poly[:,:][img_boundary[:,:]==1] = 4 - img_poly=cv2.fillPoly(img, pts =co_sep, color=(3,3,3)) - elif self.output_type == '3d': - img_poly=cv2.fillPoly(img, pts =co_img, color=(0,255,0)) - img_poly=cv2.fillPoly(img, pts =co_text_eroded, color=(255,0,0)) - img_poly=cv2.fillPoly(img, pts =co_drop, color=(0,125,255)) - - img_poly[:,:,0][img_boundary[:,:]==1]=255 - img_poly[:,:,1][img_boundary[:,:]==1]=125 - img_poly[:,:,2][img_boundary[:,:]==1]=125 - - img_poly=cv2.fillPoly(img, pts =co_sep, color=(0,0,255)) - ##img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) - - #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png') - try: - #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png') - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - - - #except: - #pass def run(self,config_params): self.get_content_of_dir() self.get_images_of_ground_truth(config_params) @@ -1631,9 +755,9 @@ class pagexml2word: ) @click.option( - "--layout_config", - "-lc", - help="config file of prefered layout.", + "--config", + "-cfg", + help="config file of prefered layout or use case.", type=click.Path(exists=True, dir_okay=False), ) @@ -1642,21 +766,16 @@ class pagexml2word: "-to", help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 2D image array is 3d but only information of one channel would be enough since all channels have the same values.", ) -@click.option( - "--experiment", - "-exp", - help="experiment of ineterst. 
Word , textline , glyph and textregion are desired options.", -) -def main(dir_xml,dir_out,type_output,experiment,layout_config): - if layout_config: - with open(layout_config) as f: +def main(dir_xml,dir_out,type_output,config): + if config: + with open(config) as f: config_params = json.load(f) else: print("passed") config_params = None - x=pagexml2word(dir_xml,dir_out,type_output,experiment, layout_config) + x=pagexml2label(dir_xml,dir_out,type_output, config) x.run(config_params) if __name__=="__main__": main() From 61487bf782238ff7af96927f2c0c9108191f9ad0 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 23 May 2024 17:36:23 +0200 Subject: [PATCH 053/123] use case printspace is added --- train/pagexml2label.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train/pagexml2label.py b/train/pagexml2label.py index 16cda8b..94596db 100644 --- a/train/pagexml2label.py +++ b/train/pagexml2label.py @@ -126,7 +126,7 @@ class pagexml2label: y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) - if self.config and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph'): + if self.config and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): keys = list(config_params.keys()) if "artificial_class_label" in keys: artificial_class_rgb_color = (255,255,0) @@ -140,6 +140,9 @@ class pagexml2label: region_tags = np.unique([x for x in alltags if x.endswith('Word')]) elif config_params['use_case']=='glyph': region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + co_use_case = [] for tag in region_tags: @@ -149,6 +152,8 @@ class pagexml2label: tag_endings = ['}Word','}word'] elif config_params['use_case']=='glyph': tag_endings = 
['}Glyph','}glyph'] + elif config_params['use_case']=='printspace': + tag_endings = ['}PrintSpace','}printspace'] if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): for nn in root1.iter(tag): From d346b317fb5dea9afefa4fd95587f0c8201cd5d7 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 24 May 2024 14:42:58 +0200 Subject: [PATCH 054/123] machine based reading order training dataset generator is added --- train/generate_gt_for_training.py | 194 +++++ train/gt_for_enhancement_creator.py | 31 - train/gt_gen_utils.py | 1239 +++++++++++++++++++++++++++ train/pagexml2label.py | 789 ----------------- 4 files changed, 1433 insertions(+), 820 deletions(-) create mode 100644 train/generate_gt_for_training.py delete mode 100644 train/gt_for_enhancement_creator.py create mode 100644 train/gt_gen_utils.py delete mode 100644 train/pagexml2label.py diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py new file mode 100644 index 0000000..e296029 --- /dev/null +++ b/train/generate_gt_for_training.py @@ -0,0 +1,194 @@ +import click +import json +from gt_gen_utils import * +from tqdm import tqdm + +@click.group() +def main(): + pass + +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out", + "-do", + help="directory where ground truth images would be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--config", + "-cfg", + help="config file of prefered layout or use case.", + type=click.Path(exists=True, dir_okay=False), +) + +@click.option( + "--type_output", + "-to", + help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 
2D image array is 3d but only information of one channel would be enough since all channels have the same values.", +) + +def pagexml2label(dir_xml,dir_out,type_output,config): + if config: + with open(config) as f: + config_params = json.load(f) + else: + print("passed") + config_params = None + gt_list = get_content_of_dir(dir_xml) + get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params) + +@main.command() +@click.option( + "--dir_imgs", + "-dis", + help="directory of images with high resolution.", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out_images", + "-dois", + help="directory where degraded images will be written.", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out_labels", + "-dols", + help="directory where original images will be written as labels.", + type=click.Path(exists=True, file_okay=False), +) +def image_enhancement(dir_imgs, dir_out_images, dir_out_labels): + #dir_imgs = './training_data_sample_enhancement/images' + #dir_out_images = './training_data_sample_enhancement/images_gt' + #dir_out_labels = './training_data_sample_enhancement/labels_gt' + + ls_imgs = os.listdir(dir_imgs) + ls_scales = [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] + + for img in tqdm(ls_imgs): + img_name = img.split('.')[0] + img_type = img.split('.')[1] + image = cv2.imread(os.path.join(dir_imgs, img)) + for i, scale in enumerate(ls_scales): + height_sc = int(image.shape[0]*scale) + width_sc = int(image.shape[1]*scale) + + image_down_scaled = resize_image(image, height_sc, width_sc) + image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1]) + + cv2.imwrite(os.path.join(dir_out_images, img_name+'_'+str(i)+'.'+img_type), image_back_to_org_scale) + cv2.imwrite(os.path.join(dir_out_labels, img_name+'_'+str(i)+'.'+img_type), image) + + +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + 
type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out_modal_image", + "-domi", + help="directory where ground truth images would be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out_classes", + "-docl", + help="directory where ground truth classes would be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--input_height", + "-ih", + help="input_height", +) +@click.option( + "--input_width", + "-iw", + help="input_width", +) + +def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width): + xml_files_ind = os.listdir(dir_xml) + input_height = int(input_height) + input_width = int(input_width) + + indexer_start= 0#55166 + max_area = 1 + min_area = 0.0001 + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = ind_xml.split('.')[0] + file_name, id_paragraph, id_header,co_text_paragraph,\ + co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) + + id_all_text = id_paragraph + id_header + co_text_all = co_text_paragraph + co_text_header + + + _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) + + img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + + for j in range(len(cy_main)): + img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 + + + texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] + texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] + + + co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + + arg_array = np.array(range(len(texts_corr_order_index_int))) + + labels_con = 
np.zeros((y_len,x_len,len(arg_array)),dtype='uint8') + for i in range(len(co_text_all)): + img_label = np.zeros((y_len,x_len,3),dtype='uint8') + img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) + + img_label[:,:,0][img_poly[:,:,0]==5] = 2 + img_label[:,:,0][img_header_and_sep[:,:]==1] = 3 + + labels_con[:,:,i] = img_label[:,:,0] + + for i in range(len(texts_corr_order_index_int)): + for j in range(len(texts_corr_order_index_int)): + if i!=j: + input_matrix = np.zeros((input_height,input_width,3)).astype(np.int8) + final_f_name = f_name+'_'+str(indexer+indexer_start) + order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] + if order_class_condition<0: + class_type = 1 + else: + class_type = 0 + + input_matrix[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) + input_matrix[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) + input_matrix[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) + + np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) + + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_matrix) + indexer = indexer+1 + + + +if __name__ == "__main__": + main() diff --git a/train/gt_for_enhancement_creator.py b/train/gt_for_enhancement_creator.py deleted file mode 100644 index 9a4274f..0000000 --- a/train/gt_for_enhancement_creator.py +++ /dev/null @@ -1,31 +0,0 @@ -import cv2 -import os - -def resize_image(seg_in, input_height, input_width): - return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) - - -dir_imgs = './training_data_sample_enhancement/images' -dir_out_imgs = './training_data_sample_enhancement/images_gt' -dir_out_labs = './training_data_sample_enhancement/labels_gt' - -ls_imgs = os.listdir(dir_imgs) - - -ls_scales = [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] - - -for img in ls_imgs: - img_name = img.split('.')[0] - img_type = img.split('.')[1] - image = 
cv2.imread(os.path.join(dir_imgs, img)) - for i, scale in enumerate(ls_scales): - height_sc = int(image.shape[0]*scale) - width_sc = int(image.shape[1]*scale) - - image_down_scaled = resize_image(image, height_sc, width_sc) - image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1]) - - cv2.imwrite(os.path.join(dir_out_imgs, img_name+'_'+str(i)+'.'+img_type), image_back_to_org_scale) - cv2.imwrite(os.path.join(dir_out_labs, img_name+'_'+str(i)+'.'+img_type), image) - diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py new file mode 100644 index 0000000..9862e29 --- /dev/null +++ b/train/gt_gen_utils.py @@ -0,0 +1,1239 @@ +import click +import sys +import os +import numpy as np +import warnings +import xml.etree.ElementTree as ET +from tqdm import tqdm +import cv2 +from shapely import geometry +from pathlib import Path + + +KERNEL = np.ones((5, 5), np.uint8) + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + +def get_content_of_dir(dir_in): + """ + Listing all ground truth page xml files. All files are needed to have xml format. 
+ """ + + gt_all=os.listdir(dir_in) + gt_list=[file for file in gt_all if file.split('.')[ len(file.split('.'))-1 ]=='xml' ] + return gt_list + +def return_parent_contours(contours, hierarchy): + contours_parent = [contours[i] for i in range(len(contours)) if hierarchy[0][i][3] == -1] + return contours_parent +def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, min_area): + found_polygons_early = list() + + jv = 0 + for c in contours: + if len(c) < 3: # A polygon cannot have less than 3 points + continue + + polygon = geometry.Polygon([point[0] for point in c]) + # area = cv2.contourArea(c) + area = polygon.area + ##print(np.prod(thresh.shape[:2])) + # Check that polygon has area greater than minimal area + # print(hierarchy[0][jv][3],hierarchy ) + if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + # print(c[0][0][1]) + found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) + jv += 1 + return found_polygons_early + +def filter_contours_area_of_image(image, contours, order_index, max_area, min_area): + found_polygons_early = list() + order_index_filtered = list() + #jv = 0 + for jv, c in enumerate(contours): + #print(len(c[0])) + c = c[0] + if len(c) < 3: # A polygon cannot have less than 3 points + continue + c_e = [point for point in c] + #print(c_e) + polygon = geometry.Polygon(c_e) + area = polygon.area + #print(area,'area') + if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint)) + order_index_filtered.append(order_index[jv]) + #jv += 1 + return found_polygons_early, order_index_filtered + +def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): + + # pixels of images are identified by 5 + if 
len(region_pre_p.shape) == 3: + cnts_images = (region_pre_p[:, :, 0] == pixel) * 1 + else: + cnts_images = (region_pre_p[:, :] == pixel) * 1 + cnts_images = cnts_images.astype(np.uint8) + cnts_images = np.repeat(cnts_images[:, :, np.newaxis], 3, axis=2) + imgray = cv2.cvtColor(cnts_images, cv2.COLOR_BGR2GRAY) + ret, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours_imgs, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + contours_imgs = return_parent_contours(contours_imgs, hierarchy) + contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) + + return contours_imgs +def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): + co_text_eroded = [] + for con in co_text: + #try: + img_boundary_in = np.zeros( (y_len,x_len) ) + img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + #print('bidiahhhhaaa') + + + + #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica + if erosion_rate > 0: + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_rate) + + pixel = 1 + min_size = 0 + con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) + + try: + co_text_eroded.append(con_eroded[0]) + except: + co_text_eroded.append(con) + + + img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) + #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) + + boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] + + img_boundary[:,:][boundary[:,:]==1] =1 + return co_text_eroded, img_boundary +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params): + """ + Reading the page xml files and write the ground truth images into given output directory. 
+ """ + ## to do: add footnote to text regions + for index in tqdm(range(len(gt_list))): + #try: + tree1 = ET.parse(dir_in+'/'+gt_list[index]) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): + keys = list(config_params.keys()) + if "artificial_class_label" in keys: + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + + textline_rgb_color = (255, 0, 0) + + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + + co_use_case = [] + + for tag in region_tags: + if config_params['use_case']=='textline': + tag_endings = ['}TextLine','}textline'] + elif config_params['use_case']=='word': + tag_endings = ['}Word','}word'] + elif config_params['use_case']=='glyph': + tag_endings = ['}Glyph','}glyph'] + elif config_params['use_case']=='printspace': + tag_endings = ['}PrintSpace','}printspace'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in 
p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + + + if "artificial_class_label" in keys: + img_boundary = np.zeros((y_len, x_len)) + erosion_rate = 1 + dilation_rate = 3 + co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + + + img = np.zeros((y_len, x_len, 3)) + if output_type == '2d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + if "artificial_class_label" in keys: + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + elif output_type == '3d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) + if "artificial_class_label" in keys: + img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] + img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] + img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + try: + cv2.imwrite(output_dir + '/' + gt_list[index].split('-')[1].split('.')[0] + '.png', + img_poly) + except: + cv2.imwrite(output_dir + '/' + gt_list[index].split('.')[0] + '.png', img_poly) + + + if config_file and config_params['use_case']=='layout': + keys = list(config_params.keys()) + if "artificial_class_on_boundry" in keys: + elements_with_artificial_class = list(config_params['artificial_class_on_boundry']) + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + #values = config_params.values() + + if 'textregions' in keys: + types_text_dict = config_params['textregions'] + types_text = list(types_text_dict.keys()) + types_text_label = list(types_text_dict.values()) + print(types_text) + if 'graphicregions' in keys: + types_graphic_dict = config_params['graphicregions'] + types_graphic = list(types_graphic_dict.keys()) + 
types_graphic_label = list(types_graphic_dict.values()) + + + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + + co_text_paragraph=[] + co_text_footnote=[] + co_text_footnote_con=[] + co_text_drop=[] + co_text_heading=[] + co_text_header=[] + co_text_marginalia=[] + co_text_catch=[] + co_text_page_number=[] + co_text_signature_mark=[] + co_sep=[] + co_img=[] + co_table=[] + co_graphic_signature=[] + co_graphic_text_annotation=[] + co_graphic_decoration=[] + co_graphic_stamp=[] + co_noise=[] + + for tag in region_tags: + if 'textregions' in keys: + if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + c_t_in_drop=[] + c_t_in_paragraph=[] + c_t_in_heading=[] + c_t_in_header=[] + c_t_in_page_number=[] + c_t_in_signature_mark=[] + c_t_in_catch=[] + c_t_in_marginalia=[] + c_t_in_footnote=[] + c_t_in_footnote_con=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + #print('birda1') + p_h=vv.attrib['points'].split(' ') + + if "drop-capital" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "footnote" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote': + c_t_in_footnote.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "footnote-continued" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': + c_t_in_footnote_con.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "heading" in types_text: + if "type" 
in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "signature-mark" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='signature-mark': + c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "header" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "catch-word" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "page-number" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='page-number': + c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "marginalia" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='marginalia': + c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "paragraph" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='paragraph': + c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + break + else: + pass + + + if vv.tag==link+'Point': + if "drop-capital" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='drop-capital': + c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "footnote" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote': + c_t_in_footnote.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "footnote-continued" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': + c_t_in_footnote_con.append([ int(np.float(vv.attrib['x'])) , 
int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "heading" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='heading': + c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "signature-mark" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='signature-mark': + c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "header" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='header': + c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "catch-word" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='catch-word': + c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "page-number" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='page-number': + c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "marginalia" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='marginalia': + c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "paragraph" in types_text: + if "type" in nn.attrib and nn.attrib['type']=='paragraph': + c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + + elif vv.tag!=link+'Point' and sumi>=1: + break + + if len(c_t_in_drop)>0: + co_text_drop.append(np.array(c_t_in_drop)) + if len(c_t_in_footnote_con)>0: + co_text_footnote_con.append(np.array(c_t_in_footnote_con)) + if len(c_t_in_footnote)>0: + co_text_footnote.append(np.array(c_t_in_footnote)) + if len(c_t_in_paragraph)>0: + co_text_paragraph.append(np.array(c_t_in_paragraph)) + if len(c_t_in_heading)>0: + co_text_heading.append(np.array(c_t_in_heading)) + + if len(c_t_in_header)>0: + co_text_header.append(np.array(c_t_in_header)) + if 
len(c_t_in_page_number)>0: + co_text_page_number.append(np.array(c_t_in_page_number)) + if len(c_t_in_catch)>0: + co_text_catch.append(np.array(c_t_in_catch)) + + if len(c_t_in_signature_mark)>0: + co_text_signature_mark.append(np.array(c_t_in_signature_mark)) + + if len(c_t_in_marginalia)>0: + co_text_marginalia.append(np.array(c_t_in_marginalia)) + + + if 'graphicregions' in keys: + if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in_stamp=[] + c_t_in_text_annotation=[] + c_t_in_decoration=[] + c_t_in_signature=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + if "handwritten-annotation" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "decoration" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='decoration': + c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "stamp" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='stamp': + c_t_in_stamp.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + if "signature" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='signature': + c_t_in_signature.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + + + break + else: + pass + + + if vv.tag==link+'Point': + if "handwritten-annotation" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': + c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "decoration" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='decoration': + 
c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "stamp" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='stamp': + c_t_in_stamp.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if "signature" in types_graphic: + if "type" in nn.attrib and nn.attrib['type']=='signature': + c_t_in_signature.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + if len(c_t_in_text_annotation)>0: + co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) + if len(c_t_in_decoration)>0: + co_graphic_decoration.append(np.array(c_t_in_decoration)) + if len(c_t_in_stamp)>0: + co_graphic_stamp.append(np.array(c_t_in_stamp)) + if len(c_t_in_signature)>0: + co_graphic_signature.append(np.array(c_t_in_signature)) + + if 'imageregion' in keys: + if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + + if 'separatorregion' in keys: + if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , 
int(np.float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + + if 'tableregion' in keys: + if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + if 'noiseregion' in keys: + if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + + if "artificial_class_on_boundry" in keys: + img_boundary = np.zeros( (y_len,x_len) ) + if "paragraph" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_paragraph, img_boundary = update_region_contours(co_text_paragraph, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "drop-capital" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + co_text_drop, img_boundary = update_region_contours(co_text_drop, img_boundary, erosion_rate, dilation_rate, y_len, 
x_len ) + if "catch-word" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + co_text_catch, img_boundary = update_region_contours(co_text_catch, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "page-number" in elements_with_artificial_class: + erosion_rate = 0 + dilation_rate = 4 + co_text_page_number, img_boundary = update_region_contours(co_text_page_number, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "header" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_header, img_boundary = update_region_contours(co_text_header, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "heading" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_heading, img_boundary = update_region_contours(co_text_heading, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "signature-mark" in elements_with_artificial_class: + erosion_rate = 1 + dilation_rate = 4 + co_text_signature_mark, img_boundary = update_region_contours(co_text_signature_mark, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "marginalia" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_marginalia, img_boundary = update_region_contours(co_text_marginalia, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "footnote" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_footnote, img_boundary = update_region_contours(co_text_footnote, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "footnote-continued" in elements_with_artificial_class: + erosion_rate = 2 + dilation_rate = 4 + co_text_footnote_con, img_boundary = update_region_contours(co_text_footnote_con, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + + + + img = np.zeros( (y_len,x_len,3) ) + + if output_type == '3d': + + if 'graphicregions' in keys: + if "handwritten-annotation" in types_graphic: + 
img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=labels_rgb_color[ config_params['graphicregions']['handwritten-annotation']]) + if "signature" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=labels_rgb_color[ config_params['graphicregions']['signature']]) + if "decoration" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=labels_rgb_color[ config_params['graphicregions']['decoration']]) + if "stamp" in types_graphic: + img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=labels_rgb_color[ config_params['graphicregions']['stamp']]) + + if 'imageregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) + if 'separatorregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_sep, color=labels_rgb_color[ config_params['separatorregion']]) + if 'tableregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) + if 'noiseregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) + + if 'textregions' in keys: + if "paragraph" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=labels_rgb_color[ config_params['textregions']['paragraph']]) + if "footnote" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=labels_rgb_color[ config_params['textregions']['footnote']]) + if "footnote-continued" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=labels_rgb_color[ config_params['textregions']['footnote-continued']]) + if "heading" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=labels_rgb_color[ config_params['textregions']['heading']]) + if "header" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_header, color=labels_rgb_color[ config_params['textregions']['header']]) + if "catch-word" in types_text: + img_poly=cv2.fillPoly(img, pts 
=co_text_catch, color=labels_rgb_color[ config_params['textregions']['catch-word']]) + if "signature-mark" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=labels_rgb_color[ config_params['textregions']['signature-mark']]) + if "page-number" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=labels_rgb_color[ config_params['textregions']['page-number']]) + if "marginalia" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=labels_rgb_color[ config_params['textregions']['marginalia']]) + if "drop-capital" in types_text: + img_poly=cv2.fillPoly(img, pts =co_text_drop, color=labels_rgb_color[ config_params['textregions']['drop-capital']]) + + if "artificial_class_on_boundry" in keys: + img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] + img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] + img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + + + + elif output_type == '2d': + if 'graphicregions' in keys: + if "handwritten-annotation" in types_graphic: + color_label = config_params['graphicregions']['handwritten-annotation'] + img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(color_label,color_label,color_label)) + if "signature" in types_graphic: + color_label = config_params['graphicregions']['signature'] + img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=(color_label,color_label,color_label)) + if "decoration" in types_graphic: + color_label = config_params['graphicregions']['decoration'] + img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(color_label,color_label,color_label)) + if "stamp" in types_graphic: + color_label = config_params['graphicregions']['stamp'] + img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=(color_label,color_label,color_label)) + + if 'imageregion' in keys: + color_label = config_params['imageregion'] + img_poly=cv2.fillPoly(img, pts =co_img, 
color=(color_label,color_label,color_label)) + if 'separatorregion' in keys: + color_label = config_params['separatorregion'] + img_poly=cv2.fillPoly(img, pts =co_sep, color=(color_label,color_label,color_label)) + if 'tableregion' in keys: + color_label = config_params['tableregion'] + img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label)) + if 'noiseregion' in keys: + color_label = config_params['noiseregion'] + img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) + + if 'textregions' in keys: + if "paragraph" in types_text: + color_label = config_params['textregions']['paragraph'] + img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(color_label,color_label,color_label)) + if "footnote" in types_text: + color_label = config_params['textregions']['footnote'] + img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=(color_label,color_label,color_label)) + if "footnote-continued" in types_text: + color_label = config_params['textregions']['footnote-continued'] + img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=(color_label,color_label,color_label)) + if "heading" in types_text: + color_label = config_params['textregions']['heading'] + img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(color_label,color_label,color_label)) + if "header" in types_text: + color_label = config_params['textregions']['header'] + img_poly=cv2.fillPoly(img, pts =co_text_header, color=(color_label,color_label,color_label)) + if "catch-word" in types_text: + color_label = config_params['textregions']['catch-word'] + img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(color_label,color_label,color_label)) + if "signature-mark" in types_text: + color_label = config_params['textregions']['signature-mark'] + img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(color_label,color_label,color_label)) + if "page-number" in types_text: + color_label = config_params['textregions']['page-number'] + 
def find_new_features_of_contours(contours_main):
    """Compute per-contour geometric features.

    :param contours_main: list of contours; each contour may have shape
        ``(N, 2)``, ``(N, 1, 2)`` (OpenCV ``findContours`` output) or
        ``(1, N, 2)`` — all are normalized to ``(N, 2)`` points.
    :return: tuple ``(cx, cy, x_min, x_max, y_min, y_max, y_at_x_min)`` where
        ``cx``/``cy`` are centroid coordinates from image moments and
        ``y_at_x_min`` is the y coordinate of each contour's leftmost point.
    """
    moments = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
    # 1e-32 guards against division by zero for degenerate (zero-area) contours.
    cx_main = [m["m10"] / (m["m00"] + 1e-32) for m in moments]
    cy_main = [m["m01"] / (m["m00"] + 1e-32) for m in moments]

    # Normalize every contour to an (N, 2) point array instead of the old
    # bare try/except shape dispatch.  The old "try" branch indexed
    # contours[j][0][:, 0], which for (N, 1, 2) contours silently looked at
    # the FIRST point only; reshaping fixes that while remaining equivalent
    # for the shapes the old code handled correctly.
    points = [np.asarray(c).reshape(-1, 2) for c in contours_main]

    x_min_main = np.array([np.min(p[:, 0]) for p in points])
    x_max_main = np.array([np.max(p[:, 0]) for p in points])
    y_min_main = np.array([np.min(p[:, 1]) for p in points])
    y_max_main = np.array([np.max(p[:, 1]) for p in points])
    # y coordinate of the leftmost point of each contour.
    y_corr_x_min_from_argmin = np.array([p[np.argmin(p[:, 0]), 1] for p in points])

    return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
def read_xml(xml_file):
    """Parse a PAGE-XML ground-truth file and extract region polygons.

    Collects coordinates of text / graphic / image / separator / table / noise
    regions and rasterizes a subset into a label image (paragraph=1,
    heading/header=2, marginalia=3, image=4, separator=5).

    :param xml_file: path to the PAGE-XML file.
    :return: tuple ``(file_name, id_paragraph, id_header, co_text_paragraph,
        co_text_header, tot_region_ref, x_len, y_len, index_tot_regions,
        img_poly)``.
    """
    file_name = Path(xml_file).stem
    tree1 = ET.parse(xml_file)
    root1 = tree1.getroot()
    alltags = [elem.tag for elem in root1.iter()]
    # Every tag carries the same XML namespace; recover it from the first tag.
    link = alltags[0].split('}')[0] + '}'

    index_tot_regions = []
    tot_region_ref = []

    # Page size; if several Page elements exist the last one wins (as before).
    for page in root1.iter(link + 'Page'):
        y_len = int(page.attrib['imageHeight'])
        x_len = int(page.attrib['imageWidth'])

    # Reading-order references.
    for ref in root1.iter(link + 'RegionRefIndexed'):
        index_tot_regions.append(ref.attrib['index'])
        tot_region_ref.append(ref.attrib['regionRef'])

    region_tags = np.unique([x for x in alltags if x.endswith('Region')])

    co_text_paragraph = []
    co_text_drop = []
    co_text_heading = []
    co_text_header = []
    co_text_marginalia = []
    co_text_catch = []
    co_text_page_number = []
    co_text_signature_mark = []
    co_sep = []
    co_img = []
    co_table = []
    co_graphic = []
    co_graphic_text_annotation = []
    co_graphic_decoration = []
    co_noise = []

    id_paragraph = []
    id_header = []
    id_heading = []
    id_marginalia = []

    def parse_points(points_attr):
        # "x1,y1 x2,y2 ..." -> (N, 2) int array.
        return np.array([[int(p.split(',')[0]), int(p.split(',')[1])]
                         for p in points_attr.split(' ')])

    def parse_point(elem):
        # <Point x=".." y=".."/> -> [x, y].  np.float was removed from NumPy
        # (>= 1.20); the builtin float() is the correct replacement.
        return [int(float(elem.attrib['x'])), int(float(elem.attrib['y']))]

    def collect_simple_region(region):
        # Regions without sub-types: the first non-empty Coords wins,
        # otherwise accumulate Point children until a non-Point tag follows.
        points = []
        found = 0
        for elem in region.iter():
            if elem.tag == link + 'Coords':
                if elem.attrib:
                    points.append(parse_points(elem.attrib['points']))
                    break
            if elem.tag == link + 'Point':
                points.append(parse_point(elem))
                found += 1
            elif elem.tag != link + 'Point' and found >= 1:
                break
        return points

    # NOTE: the old code also collected TextEquiv/Unicode strings into local
    # *_text lists that were never returned or used; that dead code is dropped.
    for tag in region_tags:
        if tag.endswith('}TextRegion') or tag.endswith('}Textregion'):
            # (coordinate bucket, id bucket or None) per region @type; anything
            # else — including a missing @type — is treated as a paragraph,
            # exactly like the original if/elif/else chain.
            text_targets = {
                'drop-capital': (co_text_drop, None),
                'heading': (co_text_heading, id_heading),
                'signature-mark': (co_text_signature_mark, None),
                'header': (co_text_header, id_header),
                'catch-word': (co_text_catch, None),
                'page-number': (co_text_page_number, None),
                'marginalia': (co_text_marginalia, id_marginalia),
            }
            for nn in root1.iter(tag):
                target, ids = text_targets.get(nn.attrib.get('type'),
                                               (co_text_paragraph, id_paragraph))
                points = []
                found = 0
                for elem in nn.iter():
                    if elem.tag == link + 'Coords':
                        if elem.attrib:
                            if ids is not None:
                                ids.append(nn.attrib['id'])
                            points.append(parse_points(elem.attrib['points']))
                            break
                    if elem.tag == link + 'Point':
                        # The original appended the region id once per Point;
                        # behavior preserved.
                        if ids is not None:
                            ids.append(nn.attrib['id'])
                        points.append(parse_point(elem))
                        found += 1
                    elif elem.tag != link + 'Point' and found >= 1:
                        break
                if points:
                    target.append(np.array(points))

        elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'):
            graphic_targets = {
                'handwritten-annotation': co_graphic_text_annotation,
                'decoration': co_graphic_decoration,
            }
            for nn in root1.iter(tag):
                target = graphic_targets.get(nn.attrib.get('type'), co_graphic)
                points = []
                for elem in nn.iter():
                    if elem.tag == link + 'Coords':
                        if elem.attrib:
                            points.append(parse_points(elem.attrib['points']))
                            break
                    # Unlike the other region kinds, the original graphic loop
                    # had no early break after trailing non-Point tags.
                    if elem.tag == link + 'Point':
                        points.append(parse_point(elem))
                if points:
                    target.append(np.array(points))

        elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'):
            for nn in root1.iter(tag):
                co_img.append(np.array(collect_simple_region(nn)))

        elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'):
            for nn in root1.iter(tag):
                co_sep.append(np.array(collect_simple_region(nn)))

        elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'):
            for nn in root1.iter(tag):
                co_table.append(np.array(collect_simple_region(nn)))

        elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
            for nn in root1.iter(tag):
                co_noise.append(np.array(collect_simple_region(nn)))

    img = np.zeros((y_len, x_len, 3))

    img_poly = cv2.fillPoly(img, pts=co_text_paragraph, color=(1, 1, 1))
    img_poly = cv2.fillPoly(img, pts=co_text_heading, color=(2, 2, 2))
    img_poly = cv2.fillPoly(img, pts=co_text_header, color=(2, 2, 2))
    img_poly = cv2.fillPoly(img, pts=co_text_marginalia, color=(3, 3, 3))
    img_poly = cv2.fillPoly(img, pts=co_img, color=(4, 4, 4))
    img_poly = cv2.fillPoly(img, pts=co_sep, color=(5, 5, 5))

    return file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, \
        tot_region_ref, x_len, y_len, index_tot_regions, img_poly


def bounding_box(cnt, color, corr_order_index):
    """Return ``[x, y, w, h, color, order]`` for contour *cnt*, scaled.

    NOTE(review): reads module-level globals ``scale_w``/``scale_h`` — confirm
    they are assigned before this is called.
    """
    x, y, w, h = cv2.boundingRect(cnt)
    return [int(x * scale_w), int(y * scale_h),
            int(w * scale_w), int(h * scale_h),
            int(color), int(corr_order_index) + 1]


def resize_image(seg_in, input_height, input_width):
    """Resize with nearest-neighbour interpolation (preserves label values)."""
    return cv2.resize(seg_in, (input_width, input_height),
                      interpolation=cv2.INTER_NEAREST)


def make_image_from_bb(width_l, height_l, bb_all):
    """Rasterize bounding boxes ``[x, y, w, h, ...]`` into a binary mask.

    :param width_l: mask width in pixels.
    :param height_l: mask height in pixels.
    :param bb_all: iterable of boxes; only the first four columns are used.
    :return: float array of shape ``(height_l, width_l)`` with 1 inside boxes.
    """
    bb_all = np.array(bb_all)
    img_remade = np.zeros((height_l, width_l))
    for i in range(bb_all.shape[0]):
        img_remade[bb_all[i, 1]:bb_all[i, 1] + bb_all[i, 3],
                   bb_all[i, 0]:bb_all[i, 0] + bb_all[i, 2]] = 1
    return img_remade
-""" -KERNEL = np.ones((5, 5), np.uint8) - -class pagexml2label: - def __init__(self,dir_in, out_dir,output_type,config): - self.dir=dir_in - self.output_dir=out_dir - self.output_type=output_type - self.config=config - - def get_content_of_dir(self): - """ - Listing all ground truth page xml files. All files are needed to have xml format. - """ - - gt_all=os.listdir(self.dir) - self.gt_list=[file for file in gt_all if file.split('.')[ len(file.split('.'))-1 ]=='xml' ] - - def return_parent_contours(self,contours, hierarchy): - contours_parent = [contours[i] for i in range(len(contours)) if hierarchy[0][i][3] == -1] - return contours_parent - def filter_contours_area_of_image_tables(self,image, contours, hierarchy, max_area, min_area): - found_polygons_early = list() - - jv = 0 - for c in contours: - if len(c) < 3: # A polygon cannot have less than 3 points - continue - - polygon = geometry.Polygon([point[0] for point in c]) - # area = cv2.contourArea(c) - area = polygon.area - ##print(np.prod(thresh.shape[:2])) - # Check that polygon has area greater than minimal area - # print(hierarchy[0][jv][3],hierarchy ) - if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : - # print(c[0][0][1]) - found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) - jv += 1 - return found_polygons_early - - def return_contours_of_interested_region(self,region_pre_p, pixel, min_area=0.0002): - - # pixels of images are identified by 5 - if len(region_pre_p.shape) == 3: - cnts_images = (region_pre_p[:, :, 0] == pixel) * 1 - else: - cnts_images = (region_pre_p[:, :] == pixel) * 1 - cnts_images = cnts_images.astype(np.uint8) - cnts_images = np.repeat(cnts_images[:, :, np.newaxis], 3, axis=2) - imgray = cv2.cvtColor(cnts_images, cv2.COLOR_BGR2GRAY) - ret, thresh = cv2.threshold(imgray, 0, 255, 0) - - contours_imgs, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, 
cv2.CHAIN_APPROX_SIMPLE) - - contours_imgs = self.return_parent_contours(contours_imgs, hierarchy) - contours_imgs = self.filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) - - return contours_imgs - def update_region_contours(self, co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): - co_text_eroded = [] - for con in co_text: - #try: - img_boundary_in = np.zeros( (y_len,x_len) ) - img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #print('bidiahhhhaaa') - - - - #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica - if erosion_rate > 0: - img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_rate) - - pixel = 1 - min_size = 0 - con_eroded = self.return_contours_of_interested_region(img_boundary_in,pixel, min_size ) - - try: - co_text_eroded.append(con_eroded[0]) - except: - co_text_eroded.append(con) - - - img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) - #img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=5) - - boundary = img_boundary_in_dilated[:,:] - img_boundary_in[:,:] - - img_boundary[:,:][boundary[:,:]==1] =1 - return co_text_eroded, img_boundary - def get_images_of_ground_truth(self, config_params): - """ - Reading the page xml files and write the ground truth images into given output directory. 
- """ - ## to do: add footnote to text regions - for index in tqdm(range(len(self.gt_list))): - #try: - tree1 = ET.parse(self.dir+'/'+self.gt_list[index]) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - - - for jj in root1.iter(link+'Page'): - y_len=int(jj.attrib['imageHeight']) - x_len=int(jj.attrib['imageWidth']) - - if self.config and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): - keys = list(config_params.keys()) - if "artificial_class_label" in keys: - artificial_class_rgb_color = (255,255,0) - artificial_class_label = config_params['artificial_class_label'] - - textline_rgb_color = (255, 0, 0) - - if config_params['use_case']=='textline': - region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) - elif config_params['use_case']=='word': - region_tags = np.unique([x for x in alltags if x.endswith('Word')]) - elif config_params['use_case']=='glyph': - region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) - elif config_params['use_case']=='printspace': - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - - co_use_case = [] - - for tag in region_tags: - if config_params['use_case']=='textline': - tag_endings = ['}TextLine','}textline'] - elif config_params['use_case']=='word': - tag_endings = ['}Word','}word'] - elif config_params['use_case']=='glyph': - tag_endings = ['}Glyph','}glyph'] - elif config_params['use_case']=='printspace': - tag_endings = ['}PrintSpace','}printspace'] - - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] 
for x in p_h])) - break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_use_case.append(np.array(c_t_in)) - - - - if "artificial_class_label" in keys: - img_boundary = np.zeros((y_len, x_len)) - erosion_rate = 1 - dilation_rate = 3 - co_use_case, img_boundary = self.update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - - - img = np.zeros((y_len, x_len, 3)) - if self.output_type == '2d': - img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) - if "artificial_class_label" in keys: - img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label - elif self.output_type == '3d': - img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) - if "artificial_class_label" in keys: - img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] - img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] - img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] - - try: - cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('-')[1].split('.')[0] + '.png', - img_poly) - except: - cv2.imwrite(self.output_dir + '/' + self.gt_list[index].split('.')[0] + '.png', img_poly) - - - if self.config and config_params['use_case']=='layout': - keys = list(config_params.keys()) - if "artificial_class_on_boundry" in keys: - elements_with_artificial_class = list(config_params['artificial_class_on_boundry']) - artificial_class_rgb_color = (255,255,0) - artificial_class_label = config_params['artificial_class_label'] - #values = config_params.values() - - if 'textregions' in keys: - types_text_dict = config_params['textregions'] - types_text = list(types_text_dict.keys()) - types_text_label = list(types_text_dict.values()) - print(types_text) - if 'graphicregions' in keys: - types_graphic_dict = config_params['graphicregions'] - 
types_graphic = list(types_graphic_dict.keys()) - types_graphic_label = list(types_graphic_dict.values()) - - - labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] - - region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - - co_text_paragraph=[] - co_text_footnote=[] - co_text_footnote_con=[] - co_text_drop=[] - co_text_heading=[] - co_text_header=[] - co_text_marginalia=[] - co_text_catch=[] - co_text_page_number=[] - co_text_signature_mark=[] - co_sep=[] - co_img=[] - co_table=[] - co_graphic_signature=[] - co_graphic_text_annotation=[] - co_graphic_decoration=[] - co_graphic_stamp=[] - co_noise=[] - - for tag in region_tags: - if 'textregions' in keys: - if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): - for nn in root1.iter(tag): - c_t_in_drop=[] - c_t_in_paragraph=[] - c_t_in_heading=[] - c_t_in_header=[] - c_t_in_page_number=[] - c_t_in_signature_mark=[] - c_t_in_catch=[] - c_t_in_marginalia=[] - c_t_in_footnote=[] - c_t_in_footnote_con=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - - coords=bool(vv.attrib) - if coords: - #print('birda1') - p_h=vv.attrib['points'].split(' ') - - if "drop-capital" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "footnote" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote': - c_t_in_footnote.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "footnote-continued" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': - c_t_in_footnote_con.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in 
p_h] ) ) - - if "heading" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "signature-mark" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='signature-mark': - c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "header" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "catch-word" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "page-number" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "marginalia" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='marginalia': - c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "paragraph" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='paragraph': - c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - break - else: - pass - - - if vv.tag==link+'Point': - if "drop-capital" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "footnote" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote': - c_t_in_footnote.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "footnote-continued" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': - 
c_t_in_footnote_con.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "heading" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "signature-mark" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='signature-mark': - c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "header" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "catch-word" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "page-number" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "marginalia" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='marginalia': - c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "paragraph" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='paragraph': - c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - - elif vv.tag!=link+'Point' and sumi>=1: - break - - if len(c_t_in_drop)>0: - co_text_drop.append(np.array(c_t_in_drop)) - if len(c_t_in_footnote_con)>0: - co_text_footnote_con.append(np.array(c_t_in_footnote_con)) - if len(c_t_in_footnote)>0: - co_text_footnote.append(np.array(c_t_in_footnote)) - if len(c_t_in_paragraph)>0: - co_text_paragraph.append(np.array(c_t_in_paragraph)) - if len(c_t_in_heading)>0: - co_text_heading.append(np.array(c_t_in_heading)) - - if len(c_t_in_header)>0: - 
co_text_header.append(np.array(c_t_in_header)) - if len(c_t_in_page_number)>0: - co_text_page_number.append(np.array(c_t_in_page_number)) - if len(c_t_in_catch)>0: - co_text_catch.append(np.array(c_t_in_catch)) - - if len(c_t_in_signature_mark)>0: - co_text_signature_mark.append(np.array(c_t_in_signature_mark)) - - if len(c_t_in_marginalia)>0: - co_text_marginalia.append(np.array(c_t_in_marginalia)) - - - if 'graphicregions' in keys: - if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in_stamp=[] - c_t_in_text_annotation=[] - c_t_in_decoration=[] - c_t_in_signature=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - if "handwritten-annotation" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "decoration" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "stamp" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='stamp': - c_t_in_stamp.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "signature" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='signature': - c_t_in_signature.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - - - break - else: - pass - - - if vv.tag==link+'Point': - if "handwritten-annotation" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "decoration" in types_graphic: - if "type" in nn.attrib 
and nn.attrib['type']=='decoration': - c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "stamp" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='stamp': - c_t_in_stamp.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "signature" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='signature': - c_t_in_signature.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if len(c_t_in_text_annotation)>0: - co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) - if len(c_t_in_decoration)>0: - co_graphic_decoration.append(np.array(c_t_in_decoration)) - if len(c_t_in_stamp)>0: - co_graphic_stamp.append(np.array(c_t_in_stamp)) - if len(c_t_in_signature)>0: - co_graphic_signature.append(np.array(c_t_in_signature)) - - if 'imageregion' in keys: - if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - elif vv.tag!=link+'Point' and sumi>=1: - break - co_img.append(np.array(c_t_in)) - - - if 'separatorregion' in keys: - if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ 
int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - elif vv.tag!=link+'Point' and sumi>=1: - break - co_sep.append(np.array(c_t_in)) - - - - if 'tableregion' in keys: - if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_table.append(np.array(c_t_in)) - - if 'noiseregion' in keys: - if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): - #print('sth') - for nn in root1.iter(tag): - c_t_in=[] - sumi=0 - for vv in nn.iter(): - # check the format of coords - if vv.tag==link+'Coords': - coords=bool(vv.attrib) - if coords: - p_h=vv.attrib['points'].split(' ') - c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break - else: - pass - - - if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - #print(vv.tag,'in') - elif vv.tag!=link+'Point' and sumi>=1: - break - co_noise.append(np.array(c_t_in)) - - if "artificial_class_on_boundry" in keys: - img_boundary = np.zeros( (y_len,x_len) ) - if "paragraph" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 - co_text_paragraph, img_boundary = self.update_region_contours(co_text_paragraph, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "drop-capital" in elements_with_artificial_class: - erosion_rate = 0 - dilation_rate = 4 - co_text_drop, img_boundary = self.update_region_contours(co_text_drop, 
img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "catch-word" in elements_with_artificial_class: - erosion_rate = 0 - dilation_rate = 4 - co_text_catch, img_boundary = self.update_region_contours(co_text_catch, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "page-number" in elements_with_artificial_class: - erosion_rate = 0 - dilation_rate = 4 - co_text_page_number, img_boundary = self.update_region_contours(co_text_page_number, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "header" in elements_with_artificial_class: - erosion_rate = 1 - dilation_rate = 4 - co_text_header, img_boundary = self.update_region_contours(co_text_header, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "heading" in elements_with_artificial_class: - erosion_rate = 1 - dilation_rate = 4 - co_text_heading, img_boundary = self.update_region_contours(co_text_heading, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "signature-mark" in elements_with_artificial_class: - erosion_rate = 1 - dilation_rate = 4 - co_text_signature_mark, img_boundary = self.update_region_contours(co_text_signature_mark, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "marginalia" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 - co_text_marginalia, img_boundary = self.update_region_contours(co_text_marginalia, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "footnote" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 - co_text_footnote, img_boundary = self.update_region_contours(co_text_footnote, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - if "footnote-continued" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 - co_text_footnote_con, img_boundary = self.update_region_contours(co_text_footnote_con, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) - - - - img = np.zeros( (y_len,x_len,3) ) - - if self.output_type == 
'3d': - - if 'graphicregions' in keys: - if "handwritten-annotation" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=labels_rgb_color[ config_params['graphicregions']['handwritten-annotation']]) - if "signature" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=labels_rgb_color[ config_params['graphicregions']['signature']]) - if "decoration" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=labels_rgb_color[ config_params['graphicregions']['decoration']]) - if "stamp" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=labels_rgb_color[ config_params['graphicregions']['stamp']]) - - if 'imageregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) - if 'separatorregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_sep, color=labels_rgb_color[ config_params['separatorregion']]) - if 'tableregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) - if 'noiseregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) - - if 'textregions' in keys: - if "paragraph" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=labels_rgb_color[ config_params['textregions']['paragraph']]) - if "footnote" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=labels_rgb_color[ config_params['textregions']['footnote']]) - if "footnote-continued" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=labels_rgb_color[ config_params['textregions']['footnote-continued']]) - if "heading" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=labels_rgb_color[ config_params['textregions']['heading']]) - if "header" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_header, color=labels_rgb_color[ 
config_params['textregions']['header']]) - if "catch-word" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=labels_rgb_color[ config_params['textregions']['catch-word']]) - if "signature-mark" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=labels_rgb_color[ config_params['textregions']['signature-mark']]) - if "page-number" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=labels_rgb_color[ config_params['textregions']['page-number']]) - if "marginalia" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=labels_rgb_color[ config_params['textregions']['marginalia']]) - if "drop-capital" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=labels_rgb_color[ config_params['textregions']['drop-capital']]) - - if "artificial_class_on_boundry" in keys: - img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] - img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] - img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] - - - - - elif self.output_type == '2d': - if 'graphicregions' in keys: - if "handwritten-annotation" in types_graphic: - color_label = config_params['graphicregions']['handwritten-annotation'] - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(color_label,color_label,color_label)) - if "signature" in types_graphic: - color_label = config_params['graphicregions']['signature'] - img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=(color_label,color_label,color_label)) - if "decoration" in types_graphic: - color_label = config_params['graphicregions']['decoration'] - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(color_label,color_label,color_label)) - if "stamp" in types_graphic: - color_label = config_params['graphicregions']['stamp'] - img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=(color_label,color_label,color_label)) - - if 'imageregion' in 
keys: - color_label = config_params['imageregion'] - img_poly=cv2.fillPoly(img, pts =co_img, color=(color_label,color_label,color_label)) - if 'separatorregion' in keys: - color_label = config_params['separatorregion'] - img_poly=cv2.fillPoly(img, pts =co_sep, color=(color_label,color_label,color_label)) - if 'tableregion' in keys: - color_label = config_params['tableregion'] - img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label)) - if 'noiseregion' in keys: - color_label = config_params['noiseregion'] - img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) - - if 'textregions' in keys: - if "paragraph" in types_text: - color_label = config_params['textregions']['paragraph'] - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(color_label,color_label,color_label)) - if "footnote" in types_text: - color_label = config_params['textregions']['footnote'] - img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=(color_label,color_label,color_label)) - if "footnote-continued" in types_text: - color_label = config_params['textregions']['footnote-continued'] - img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=(color_label,color_label,color_label)) - if "heading" in types_text: - color_label = config_params['textregions']['heading'] - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(color_label,color_label,color_label)) - if "header" in types_text: - color_label = config_params['textregions']['header'] - img_poly=cv2.fillPoly(img, pts =co_text_header, color=(color_label,color_label,color_label)) - if "catch-word" in types_text: - color_label = config_params['textregions']['catch-word'] - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(color_label,color_label,color_label)) - if "signature-mark" in types_text: - color_label = config_params['textregions']['signature-mark'] - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(color_label,color_label,color_label)) - if 
"page-number" in types_text: - color_label = config_params['textregions']['page-number'] - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(color_label,color_label,color_label)) - if "marginalia" in types_text: - color_label = config_params['textregions']['marginalia'] - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(color_label,color_label,color_label)) - if "drop-capital" in types_text: - color_label = config_params['textregions']['drop-capital'] - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(color_label,color_label,color_label)) - - if "artificial_class_on_boundry" in keys: - img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label - - - - - try: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) - except: - cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.png',img_poly ) - - - def run(self,config_params): - self.get_content_of_dir() - self.get_images_of_ground_truth(config_params) - - -@click.command() -@click.option( - "--dir_xml", - "-dx", - help="directory of GT page-xml files", - type=click.Path(exists=True, file_okay=False), -) -@click.option( - "--dir_out", - "-do", - help="directory where ground truth images would be written", - type=click.Path(exists=True, file_okay=False), -) - -@click.option( - "--config", - "-cfg", - help="config file of prefered layout or use case.", - type=click.Path(exists=True, dir_okay=False), -) - -@click.option( - "--type_output", - "-to", - help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 
2D image array is 3d but only information of one channel would be enough since all channels have the same values.", -) - - -def main(dir_xml,dir_out,type_output,config): - if config: - with open(config) as f: - config_params = json.load(f) - else: - print("passed") - config_params = None - x=pagexml2label(dir_xml,dir_out,type_output, config) - x.run(config_params) -if __name__=="__main__": - main() - - - From 9638098ae7e5269a597a98937f3c239270575525 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 24 May 2024 16:39:48 +0200 Subject: [PATCH 055/123] machine based reading order training is integrated --- train/models.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ train/train.py | 31 ++++++++++++++++++++++++++++ train/utils.py | 23 +++++++++++++++++++++ 3 files changed, 109 insertions(+) diff --git a/train/models.py b/train/models.py index 4cceacd..d852ac3 100644 --- a/train/models.py +++ b/train/models.py @@ -544,4 +544,59 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay= + return model + +def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1) + + x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1) + x1 = Activation('relu')(x1) + x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1) + + x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b') + x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c') + + x1 = conv_block(x1, 3, [128, 128, 512], stage=3, 
block='a') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c') + x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d') + + x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e') + x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f') + + x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a') + x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b') + x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c') + + if pretraining: + Model(img_input , x1).load_weights(resnet50_Weights_path) + + x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1) + flattened = Flatten()(x1) + + o = Dense(256, activation='relu', name='fc512')(flattened) + o=Dropout(0.2)(o) + + o = Dense(256, activation='relu', name='fc512a')(o) + o=Dropout(0.2)(o) + + o = Dense(n_classes, activation='sigmoid', name='fc1000')(o) + model = Model(img_input , o) + return model diff --git a/train/train.py b/train/train.py index 78974d3..f338c78 100644 --- a/train/train.py +++ b/train/train.py @@ -313,4 +313,35 @@ def run(_config, n_classes, n_epochs, input_height, with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON + + elif task=='reading_order': + configuration() + model = machine_based_reading_order_model(n_classes,input_height,input_width,weight_decay,pretraining) + + dir_flow_train_imgs = os.path.join(dir_train, 'images') + dir_flow_train_labels = os.path.join(dir_train, 'labels') + + classes = os.listdir(dir_flow_train_labels) + num_rows =len(classes) + #ls_test = os.listdir(dir_flow_train_labels) + + #f1score_tot = [0] + 
indexer_start = 0 + opt = SGD(lr=0.01, momentum=0.9) + opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) + model.compile(loss="binary_crossentropy", + optimizer = opt_adam,metrics=['accuracy']) + for i in range(n_epochs): + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=1) + model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) + + with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + ''' + if f1score>f1score_tot[0]: + f1score_tot[0] = f1score + model_dir = os.path.join(dir_out,'model_best') + model.save(model_dir) + ''' + diff --git a/train/utils.py b/train/utils.py index 271d977..a2e8a9c 100644 --- a/train/utils.py +++ b/train/utils.py @@ -268,6 +268,29 @@ def IoU(Yi, y_predi): #print("Mean IoU: {:4.3f}".format(mIoU)) return mIoU +def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes): + all_labels_files = os.listdir(classes_file_dir) + ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 + while True: + for i in all_labels_files: + file_name = i.split('.')[0] + img = cv2.imread(os.path.join(modal_dir,file_name+'.png')) + + label_class = int( np.load(os.path.join(classes_file_dir,i)) ) + + ret_x[batchcount, :,:,0] = img[:,:,0]/3.0 + ret_x[batchcount, :,:,2] = img[:,:,2]/3.0 + ret_x[batchcount, :,:,1] = img[:,:,1]/5.0 + + ret_y[batchcount, :] = label_class + batchcount+=1 + if batchcount>=batchsize: + yield (ret_x, ret_y) + ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, 
task='segmentation'): c = 0 From ccf520d3c73d7c1132509434a206ddb2d504b5c2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 27 May 2024 17:23:49 +0200 Subject: [PATCH 056/123] adding rest_as_paragraph and rest_as_graphic to elements --- train/custom_config_page2label.json | 10 +- train/gt_gen_utils.py | 454 ++++++++++------------------ 2 files changed, 170 insertions(+), 294 deletions(-) diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json index d6320fa..e4c02cb 100644 --- a/train/custom_config_page2label.json +++ b/train/custom_config_page2label.json @@ -1,9 +1,9 @@ { "use_case": "layout", -"textregions":{"paragraph":1, "heading": 2, "header":2,"drop-capital": 3, "marginalia":4 ,"page-number":1 , "catch-word":1 ,"footnote": 1, "footnote-continued": 1}, -"imageregion":5, -"separatorregion":6, -"graphicregions" :{"handwritten-annotation":5, "decoration": 5, "signature": 5, "stamp": 5}, -"artificial_class_on_boundry": ["paragraph","header", "heading", "marginalia", "page-number", "catch-word", "drop-capital","footnote", "footnote-continued"], +"textregions":{ "rest_as_paragraph": 1, "header":2 , "heading":2 , "marginalia":3 }, +"imageregion":4, +"separatorregion":5, +"graphicregions" :{"rest_as_decoration":6}, +"artificial_class_on_boundry": ["paragraph"], "artificial_class_label":7 } diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 9862e29..9dc8377 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -180,7 +180,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ pass if vv.tag == link + 'Point': - c_t_in.append([int(np.float(vv.attrib['x'])), int(np.float(vv.attrib['y']))]) + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) sumi += 1 elif vv.tag != link + 'Point' and sumi >= 1: break @@ -226,7 +226,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ types_text_dict = config_params['textregions'] types_text 
= list(types_text_dict.keys()) types_text_label = list(types_text_dict.values()) - print(types_text) if 'graphicregions' in keys: types_graphic_dict = config_params['graphicregions'] types_graphic = list(types_graphic_dict.keys()) @@ -235,41 +234,20 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)] + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - - co_text_paragraph=[] - co_text_footnote=[] - co_text_footnote_con=[] - co_text_drop=[] - co_text_heading=[] - co_text_header=[] - co_text_marginalia=[] - co_text_catch=[] - co_text_page_number=[] - co_text_signature_mark=[] + co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} co_sep=[] co_img=[] co_table=[] - co_graphic_signature=[] - co_graphic_text_annotation=[] - co_graphic_decoration=[] - co_graphic_stamp=[] co_noise=[] for tag in region_tags: if 'textregions' in keys: if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): for nn in root1.iter(tag): - c_t_in_drop=[] - c_t_in_paragraph=[] - c_t_in_heading=[] - c_t_in_header=[] - c_t_in_page_number=[] - c_t_in_signature_mark=[] - c_t_in_catch=[] - c_t_in_marginalia=[] - c_t_in_footnote=[] - c_t_in_footnote_con=[] + c_t_in = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} sumi=0 for vv in nn.iter(): # check the format of coords @@ -277,143 +255,63 @@ def get_images_of_ground_truth(gt_list, dir_in, 
output_dir, output_type, config_ coords=bool(vv.attrib) if coords: - #print('birda1') p_h=vv.attrib['points'].split(' ') - if "drop-capital" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "footnote" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote': - c_t_in_footnote.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "footnote-continued" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': - c_t_in_footnote_con.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "heading" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "signature-mark" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='signature-mark': - c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "header" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "catch-word" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "page-number" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "marginalia" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='marginalia': - c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "paragraph" in types_text: - if "type" in 
nn.attrib and nn.attrib['type']=='paragraph': - c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) break else: pass - + if vv.tag==link+'Point': - if "drop-capital" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + c_t_in[nn.attrib['type']].append( [ 
int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) sumi+=1 - - if "footnote" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote': - c_t_in_footnote.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "footnote-continued" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='footnote-continued': - c_t_in_footnote_con.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "heading" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='heading': - c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "signature-mark" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='signature-mark': - c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "header" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='header': - c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "catch-word" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "page-number" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "marginalia" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='marginalia': - c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "paragraph" in types_text: - if "type" in nn.attrib and nn.attrib['type']=='paragraph': - c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - + elif vv.tag!=link+'Point' and sumi>=1: break - if len(c_t_in_drop)>0: - co_text_drop.append(np.array(c_t_in_drop)) - if 
len(c_t_in_footnote_con)>0: - co_text_footnote_con.append(np.array(c_t_in_footnote_con)) - if len(c_t_in_footnote)>0: - co_text_footnote.append(np.array(c_t_in_footnote)) - if len(c_t_in_paragraph)>0: - co_text_paragraph.append(np.array(c_t_in_paragraph)) - if len(c_t_in_heading)>0: - co_text_heading.append(np.array(c_t_in_heading)) - - if len(c_t_in_header)>0: - co_text_header.append(np.array(c_t_in_header)) - if len(c_t_in_page_number)>0: - co_text_page_number.append(np.array(c_t_in_page_number)) - if len(c_t_in_catch)>0: - co_text_catch.append(np.array(c_t_in_catch)) - - if len(c_t_in_signature_mark)>0: - co_text_signature_mark.append(np.array(c_t_in_signature_mark)) - - if len(c_t_in_marginalia)>0: - co_text_marginalia.append(np.array(c_t_in_marginalia)) - - + for element_text in list(c_t_in.keys()): + if len(c_t_in[element_text])>0: + co_text[element_text].append(np.array(c_t_in[element_text])) + if 'graphicregions' in keys: if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): #print('sth') for nn in root1.iter(tag): - c_t_in_stamp=[] - c_t_in_text_annotation=[] - c_t_in_decoration=[] - c_t_in_signature=[] + c_t_in_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} sumi=0 for vv in nn.iter(): # check the format of coords @@ -421,23 +319,22 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ coords=bool(vv.attrib) if coords: p_h=vv.attrib['points'].split(' ') - if "handwritten-annotation" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "decoration" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "stamp" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='stamp': - 
c_t_in_stamp.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - - if "signature" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='signature': - c_t_in_signature.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) break else: @@ -445,34 +342,33 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if vv.tag==link+'Point': - if "handwritten-annotation" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_graphic_without_decoration) >= 1: + if 
"type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) sumi+=1 - if "decoration" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "stamp" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='stamp': - c_t_in_stamp.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 - - if "signature" in types_graphic: - if "type" in nn.attrib and nn.attrib['type']=='signature': - c_t_in_signature.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) - sumi+=1 + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_graphic in list(c_t_in_graphic.keys()): + if len(c_t_in_graphic[element_graphic])>0: + co_graphic[element_graphic].append(np.array(c_t_in_graphic[element_graphic])) - if len(c_t_in_text_annotation)>0: - co_graphic_text_annotation.append(np.array(c_t_in_text_annotation)) - if len(c_t_in_decoration)>0: - co_graphic_decoration.append(np.array(c_t_in_decoration)) - if len(c_t_in_stamp)>0: - co_graphic_stamp.append(np.array(c_t_in_stamp)) - if len(c_t_in_signature)>0: - co_graphic_signature.append(np.array(c_t_in_signature)) if 'imageregion' in keys: if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): @@ -491,7 +387,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , 
int(float(vv.attrib['y'])) ]) sumi+=1 elif vv.tag!=link+'Point' and sumi>=1: @@ -517,7 +413,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif vv.tag!=link+'Point' and sumi>=1: @@ -545,7 +441,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -571,7 +467,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -583,59 +479,63 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "paragraph" in elements_with_artificial_class: erosion_rate = 2 dilation_rate = 4 - co_text_paragraph, img_boundary = update_region_contours(co_text_paragraph, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text['paragraph'], img_boundary = update_region_contours(co_text['paragraph'], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "drop-capital" in elements_with_artificial_class: erosion_rate = 0 dilation_rate = 4 - co_text_drop, img_boundary = update_region_contours(co_text_drop, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["drop-capital"], img_boundary = update_region_contours(co_text["drop-capital"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "catch-word" in elements_with_artificial_class: 
erosion_rate = 0 dilation_rate = 4 - co_text_catch, img_boundary = update_region_contours(co_text_catch, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["catch-word"], img_boundary = update_region_contours(co_text["catch-word"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "page-number" in elements_with_artificial_class: erosion_rate = 0 dilation_rate = 4 - co_text_page_number, img_boundary = update_region_contours(co_text_page_number, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["page-number"], img_boundary = update_region_contours(co_text["page-number"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "header" in elements_with_artificial_class: erosion_rate = 1 dilation_rate = 4 - co_text_header, img_boundary = update_region_contours(co_text_header, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["header"], img_boundary = update_region_contours(co_text["header"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "heading" in elements_with_artificial_class: erosion_rate = 1 dilation_rate = 4 - co_text_heading, img_boundary = update_region_contours(co_text_heading, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["heading"], img_boundary = update_region_contours(co_text["heading"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "signature-mark" in elements_with_artificial_class: erosion_rate = 1 dilation_rate = 4 - co_text_signature_mark, img_boundary = update_region_contours(co_text_signature_mark, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["signature-mark"], img_boundary = update_region_contours(co_text["signature-mark"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "marginalia" in elements_with_artificial_class: erosion_rate = 2 dilation_rate = 4 - co_text_marginalia, img_boundary = update_region_contours(co_text_marginalia, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + 
co_text["marginalia"], img_boundary = update_region_contours(co_text["marginalia"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "footnote" in elements_with_artificial_class: erosion_rate = 2 dilation_rate = 4 - co_text_footnote, img_boundary = update_region_contours(co_text_footnote, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["footnote"], img_boundary = update_region_contours(co_text["footnote"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "footnote-continued" in elements_with_artificial_class: erosion_rate = 2 dilation_rate = 4 - co_text_footnote_con, img_boundary = update_region_contours(co_text_footnote_con, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + co_text["footnote-continued"], img_boundary = update_region_contours(co_text["footnote-continued"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) img = np.zeros( (y_len,x_len,3) ) if output_type == '3d': - if 'graphicregions' in keys: - if "handwritten-annotation" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=labels_rgb_color[ config_params['graphicregions']['handwritten-annotation']]) - if "signature" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=labels_rgb_color[ config_params['graphicregions']['signature']]) - if "decoration" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=labels_rgb_color[ config_params['graphicregions']['decoration']]) - if "stamp" in types_graphic: - img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=labels_rgb_color[ config_params['graphicregions']['stamp']]) + if 'rest_as_decoration' in types_graphic: + types_graphic[types_graphic=='rest_as_decoration'] = 'decoration' + for element_graphic in types_graphic: + if element_graphic == 'decoration': + color_label = labels_rgb_color[ config_params['graphicregions']['rest_as_decoration']] + else: + color_label = labels_rgb_color[ 
config_params['graphicregions'][element_graphic]] + img_poly=cv2.fillPoly(img, pts =co_graphic[element_graphic], color=color_label) + else: + for element_graphic in types_graphic: + color_label = labels_rgb_color[ config_params['graphicregions'][element_graphic]] + img_poly=cv2.fillPoly(img, pts =co_graphic[element_graphic], color=color_label) + if 'imageregion' in keys: img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) @@ -647,26 +547,19 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) if 'textregions' in keys: - if "paragraph" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=labels_rgb_color[ config_params['textregions']['paragraph']]) - if "footnote" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=labels_rgb_color[ config_params['textregions']['footnote']]) - if "footnote-continued" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=labels_rgb_color[ config_params['textregions']['footnote-continued']]) - if "heading" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=labels_rgb_color[ config_params['textregions']['heading']]) - if "header" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_header, color=labels_rgb_color[ config_params['textregions']['header']]) - if "catch-word" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=labels_rgb_color[ config_params['textregions']['catch-word']]) - if "signature-mark" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=labels_rgb_color[ config_params['textregions']['signature-mark']]) - if "page-number" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=labels_rgb_color[ config_params['textregions']['page-number']]) - if "marginalia" in types_text: - 
img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=labels_rgb_color[ config_params['textregions']['marginalia']]) - if "drop-capital" in types_text: - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=labels_rgb_color[ config_params['textregions']['drop-capital']]) + if 'rest_as_paragraph' in types_text: + types_text[types_text=='rest_as_paragraph'] = 'paragraph' + for element_text in types_text: + if element_text == 'paragraph': + color_label = labels_rgb_color[ config_params['textregions']['rest_as_paragraph']] + else: + color_label = labels_rgb_color[ config_params['textregions'][element_text]] + img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) + else: + for element_text in types_text: + color_label = labels_rgb_color[ config_params['textregions'][element_text]] + img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) + if "artificial_class_on_boundry" in keys: img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] @@ -678,18 +571,19 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ elif output_type == '2d': if 'graphicregions' in keys: - if "handwritten-annotation" in types_graphic: - color_label = config_params['graphicregions']['handwritten-annotation'] - img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(color_label,color_label,color_label)) - if "signature" in types_graphic: - color_label = config_params['graphicregions']['signature'] - img_poly=cv2.fillPoly(img, pts =co_graphic_signature, color=(color_label,color_label,color_label)) - if "decoration" in types_graphic: - color_label = config_params['graphicregions']['decoration'] - img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(color_label,color_label,color_label)) - if "stamp" in types_graphic: - color_label = config_params['graphicregions']['stamp'] - img_poly=cv2.fillPoly(img, pts =co_graphic_stamp, color=(color_label,color_label,color_label)) + if 'rest_as_decoration' in 
types_graphic: + types_graphic[types_graphic=='rest_as_decoration'] = 'decoration' + for element_graphic in types_graphic: + if element_graphic == 'decoration': + color_label = config_params['graphicregions']['rest_as_decoration'] + else: + color_label = config_params['graphicregions'][element_graphic] + img_poly=cv2.fillPoly(img, pts =co_graphic[element_graphic], color=color_label) + else: + for element_graphic in types_graphic: + color_label = config_params['graphicregions'][element_graphic] + img_poly=cv2.fillPoly(img, pts =co_graphic[element_graphic], color=color_label) + if 'imageregion' in keys: color_label = config_params['imageregion'] @@ -705,36 +599,18 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label)) if 'textregions' in keys: - if "paragraph" in types_text: - color_label = config_params['textregions']['paragraph'] - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(color_label,color_label,color_label)) - if "footnote" in types_text: - color_label = config_params['textregions']['footnote'] - img_poly=cv2.fillPoly(img, pts =co_text_footnote, color=(color_label,color_label,color_label)) - if "footnote-continued" in types_text: - color_label = config_params['textregions']['footnote-continued'] - img_poly=cv2.fillPoly(img, pts =co_text_footnote_con, color=(color_label,color_label,color_label)) - if "heading" in types_text: - color_label = config_params['textregions']['heading'] - img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(color_label,color_label,color_label)) - if "header" in types_text: - color_label = config_params['textregions']['header'] - img_poly=cv2.fillPoly(img, pts =co_text_header, color=(color_label,color_label,color_label)) - if "catch-word" in types_text: - color_label = config_params['textregions']['catch-word'] - img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(color_label,color_label,color_label)) - 
if "signature-mark" in types_text: - color_label = config_params['textregions']['signature-mark'] - img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(color_label,color_label,color_label)) - if "page-number" in types_text: - color_label = config_params['textregions']['page-number'] - img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(color_label,color_label,color_label)) - if "marginalia" in types_text: - color_label = config_params['textregions']['marginalia'] - img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(color_label,color_label,color_label)) - if "drop-capital" in types_text: - color_label = config_params['textregions']['drop-capital'] - img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(color_label,color_label,color_label)) + if 'rest_as_paragraph' in types_text: + types_text[types_text=='rest_as_paragraph'] = 'paragraph' + for element_text in types_text: + if element_text == 'paragraph': + color_label = config_params['textregions']['rest_as_paragraph'] + else: + color_label = config_params['textregions'][element_text] + img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) + else: + for element_text in types_text: + color_label = config_params['textregions'][element_text] + img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) if "artificial_class_on_boundry" in keys: img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label @@ -947,51 +823,51 @@ def read_xml(xml_file): if "type" in nn.attrib and nn.attrib['type']=='drop-capital': #if nn.attrib['type']=='paragraph': - c_t_in_drop.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='heading': id_heading.append(nn.attrib['id']) - c_t_in_heading.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_heading.append([ int(float(vv.attrib['x'])) , 
int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': - c_t_in_signature_mark.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(c_t_in_paragraph) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='header': id_header.append(nn.attrib['id']) - c_t_in_header.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_header.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(c_t_in_paragraph) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='marginalia': id_marginalia.append(nn.attrib['id']) - c_t_in_marginalia.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(c_t_in_paragraph) sumi+=1 else: id_paragraph.append(nn.attrib['id']) - c_t_in_paragraph.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(c_t_in_paragraph) sumi+=1 - #c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + #c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -1057,16 +933,16 @@ def read_xml(xml_file): if "type" in nn.attrib and 
nn.attrib['type']=='handwritten-annotation': #if nn.attrib['type']=='paragraph': - c_t_in_text_annotation.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) #print(c_t_in_paragraph) sumi+=1 else: - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 if len(c_t_in_text_annotation)>0: @@ -1096,7 +972,7 @@ def read_xml(xml_file): if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -1123,7 +999,7 @@ def read_xml(xml_file): if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -1150,7 +1026,7 @@ def read_xml(xml_file): if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: @@ -1176,7 +1052,7 @@ def read_xml(xml_file): if vv.tag==link+'Point': - c_t_in.append([ int(np.float(vv.attrib['x'])) , int(np.float(vv.attrib['y'])) ]) + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: From 
467bbb2884e1b900e819370b1e88853c24d60e90 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 28 May 2024 10:01:17 +0200 Subject: [PATCH 057/123] pass degrading scales for image enhancement as a json file --- train/generate_gt_for_training.py | 16 ++++++++++------ train/scales_enhancement.json | 3 +++ 2 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 train/scales_enhancement.json diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index e296029..2a2a776 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -64,13 +64,17 @@ def pagexml2label(dir_xml,dir_out,type_output,config): help="directory where original images will be written as labels.", type=click.Path(exists=True, file_okay=False), ) -def image_enhancement(dir_imgs, dir_out_images, dir_out_labels): - #dir_imgs = './training_data_sample_enhancement/images' - #dir_out_images = './training_data_sample_enhancement/images_gt' - #dir_out_labels = './training_data_sample_enhancement/labels_gt' - +@click.option( + "--scales", + "-scs", + help="json dictionary where the scales are written.", + type=click.Path(exists=True, dir_okay=False), +) +def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales): ls_imgs = os.listdir(dir_imgs) - ls_scales = [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] + with open(scales) as f: + scale_dict = json.load(f) + ls_scales = scale_dict['scales'] for img in tqdm(ls_imgs): img_name = img.split('.')[0] diff --git a/train/scales_enhancement.json b/train/scales_enhancement.json new file mode 100644 index 0000000..58034f0 --- /dev/null +++ b/train/scales_enhancement.json @@ -0,0 +1,3 @@ +{ + "scales" : [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] +} From cc7577d2c121ca14180bbc732355e35d7be80af8 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 28 May 2024 10:14:16 +0200 Subject: [PATCH 058/123] min area size of text region passes as an argument for machine based reading 
order --- train/generate_gt_for_training.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 2a2a776..cf2b2a6 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -116,22 +116,28 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales): @click.option( "--input_height", "-ih", - help="input_height", + help="input height", ) @click.option( "--input_width", "-iw", - help="input_width", + help="input width", +) +@click.option( + "--min_area_size", + "-min", + help="min area size of regions considered for reading order training.", ) -def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width): +def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size): xml_files_ind = os.listdir(dir_xml) input_height = int(input_height) input_width = int(input_width) + min_area = float(min_area_size) indexer_start= 0#55166 max_area = 1 - min_area = 0.0001 + #min_area = 0.0001 for ind_xml in tqdm(xml_files_ind): indexer = 0 From 4fb45a671114c8d44b100dd799e097a3b669c27a Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 28 May 2024 16:48:51 +0200 Subject: [PATCH 059/123] inference for reading order --- train/gt_gen_utils.py | 134 +++++++++-------------------- train/inference.py | 196 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 227 insertions(+), 103 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 9dc8377..0286ac7 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -38,11 +38,8 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m polygon = geometry.Polygon([point[0] for point in c]) # area = cv2.contourArea(c) area = polygon.area - ##print(np.prod(thresh.shape[:2])) # Check that polygon has area greater than minimal area - # 
print(hierarchy[0][jv][3],hierarchy ) if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : - # print(c[0][0][1]) found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) jv += 1 return found_polygons_early @@ -52,15 +49,12 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar order_index_filtered = list() #jv = 0 for jv, c in enumerate(contours): - #print(len(c[0])) c = c[0] if len(c) < 3: # A polygon cannot have less than 3 points continue c_e = [point for point in c] - #print(c_e) polygon = geometry.Polygon(c_e) area = polygon.area - #print(area,'area') if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint)) order_index_filtered.append(order_index[jv]) @@ -88,12 +82,8 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): co_text_eroded = [] for con in co_text: - #try: img_boundary_in = np.zeros( (y_len,x_len) ) img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #print('bidiahhhhaaa') - - #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica if erosion_rate > 0: @@ -626,8 +616,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ def find_new_features_of_contours(contours_main): - - #print(contours_main[0][0][:, 0]) areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))]) M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))] @@ -658,8 +646,6 @@ def find_new_features_of_contours(contours_main): y_min_main = np.array([np.min(contours_main[j][:, 1]) for j in range(len(contours_main))]) 
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))]) - # dis_x=np.abs(x_max_main-x_min_main) - return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin def read_xml(xml_file): file_name = Path(xml_file).stem @@ -675,13 +661,11 @@ def read_xml(xml_file): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) - for jj in root1.iter(link+'RegionRefIndexed'): index_tot_regions.append(jj.attrib['index']) tot_region_ref.append(jj.attrib['regionRef']) region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - #print(region_tags) co_text_paragraph=[] co_text_drop=[] co_text_heading=[] @@ -698,7 +682,6 @@ def read_xml(xml_file): co_graphic_decoration=[] co_noise=[] - co_text_paragraph_text=[] co_text_drop_text=[] co_text_heading_text=[] @@ -715,7 +698,6 @@ def read_xml(xml_file): co_graphic_decoration_text=[] co_noise_text=[] - id_paragraph = [] id_header = [] id_heading = [] @@ -726,14 +708,8 @@ def read_xml(xml_file): for nn in root1.iter(tag): for child2 in nn: tag2 = child2.tag - #print(child2.tag) if tag2.endswith('}TextEquiv') or tag2.endswith('}TextEquiv'): - #children2 = childtext.getchildren() - #rank = child2.find('Unicode').text for childtext2 in child2: - #rank = childtext2.find('Unicode').text - #if childtext2.tag.endswith('}PlainText') or childtext2.tag.endswith('}PlainText'): - #print(childtext2.text) if childtext2.tag.endswith('}Unicode') or childtext2.tag.endswith('}Unicode'): if "type" in nn.attrib and nn.attrib['type']=='drop-capital': co_text_drop_text.append(childtext2.text) @@ -743,10 +719,10 @@ def read_xml(xml_file): co_text_signature_mark_text.append(childtext2.text) elif "type" in nn.attrib and nn.attrib['type']=='header': co_text_header_text.append(childtext2.text) - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - co_text_catch_text.append(childtext2.text) - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - 
co_text_page_number_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###co_text_catch_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': + ###co_text_page_number_text.append(childtext2.text) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': co_text_marginalia_text.append(childtext2.text) else: @@ -774,7 +750,6 @@ def read_xml(xml_file): if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -792,27 +767,22 @@ def read_xml(xml_file): c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - elif "type" in nn.attrib and nn.attrib['type']=='page-number': + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) + ###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) else: - #print(nn.attrib['id']) - id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) break else: @@ -821,7 +791,6 @@ def read_xml(xml_file): if 
vv.tag==link+'Point': if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -835,7 +804,6 @@ def read_xml(xml_file): elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='header': id_header.append(nn.attrib['id']) @@ -843,33 +811,26 @@ def read_xml(xml_file): sumi+=1 - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - sumi+=1 + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - - c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 + ###c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='marginalia': id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 else: id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 - #c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break @@ -895,7 +856,6 @@ def read_xml(xml_file): elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] 
c_t_in_text_annotation=[] @@ -907,40 +867,31 @@ def read_xml(xml_file): coords=bool(vv.attrib) if coords: p_h=vv.attrib['points'].split(' ') - #c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - + elif "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) + else: c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break else: pass if vv.tag==link+'Point': - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 + else: c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -955,7 +906,6 @@ def read_xml(xml_file): elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -974,7 +924,6 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break co_img.append(np.array(c_t_in)) @@ -982,7 +931,6 @@ def read_xml(xml_file): elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1001,7 +949,6 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ 
int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break co_sep.append(np.array(c_t_in)) @@ -1009,7 +956,6 @@ def read_xml(xml_file): elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1028,14 +974,13 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: break co_table.append(np.array(c_t_in)) co_table_text.append(' ') elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1054,40 +999,22 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) co_noise_text.append(' ') - img = np.zeros( (y_len,x_len,3) ) - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1)) img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2)) img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2)) - #img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(125,255,125)) - #img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0)) - #img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(1,125,255)) - #img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(1,125,0)) img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(3,3,3)) - #img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(1,125,255)) - - #img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125)) img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) - #img_poly=cv2.fillPoly(img, pts =co_table, color=(1,255,255)) - #img_poly=cv2.fillPoly(img, pts 
=co_graphic, color=(255,125,125)) - #img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255)) - #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg') - ###try: - ####print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg') - ###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.jpg',img_poly ) - ###except: - ###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg',img_poly ) - return file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ + return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ tot_region_ref,x_len, y_len,index_tot_regions, img_poly @@ -1113,3 +1040,24 @@ def make_image_from_bb(width_l, height_l, bb_all): for i in range(bb_all.shape[0]): img_remade[bb_all[i,1]:bb_all[i,1]+bb_all[i,3],bb_all[i,0]:bb_all[i,0]+bb_all[i,2] ] = 1 return img_remade + +def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be_updated, innner_index_pr_pos, pr_list, pos_list,list_inp): + list_inp.pop(index_element_to_be_updated) + if len(pr_list)>0: + list_inp.insert(index_element_to_be_updated, pr_list) + else: + index_element_to_be_updated = index_element_to_be_updated -1 + + list_inp.insert(index_element_to_be_updated+1, [innner_index_pr_pos]) + if len(pos_list)>0: + list_inp.insert(index_element_to_be_updated+2, pos_list) + + len_all_elements = [len(i) for i in list_inp] + list_len_bigger_1 = np.where(np.array(len_all_elements)>1) + list_len_bigger_1 = list_len_bigger_1[0] + + if len(list_len_bigger_1)>0: + early_list_bigger_than_one = list_len_bigger_1[0] + else: + early_list_bigger_than_one = -20 + return list_inp, early_list_bigger_than_one diff --git a/train/inference.py b/train/inference.py index 94e318d..73b4ed8 100644 --- a/train/inference.py +++ b/train/inference.py @@ -11,13 +11,11 @@ from tensorflow.keras import layers import tensorflow.keras.losses from tensorflow.keras.layers import * 
from models import * +from gt_gen_utils import * import click import json from tensorflow.python.keras import backend as tensorflow_backend - - - - +import xml.etree.ElementTree as ET with warnings.catch_warnings(): @@ -29,7 +27,7 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file): self.image=image self.patches=patches self.save=save @@ -37,6 +35,7 @@ class sbb_predict: self.ground_truth=ground_truth self.task=task self.config_params_model=config_params_model + self.xml_file = xml_file def resize_image(self,img_in,input_height,input_width): return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) @@ -166,7 +165,7 @@ class sbb_predict: ##if self.weights_dir!=None: ##self.model.load_weights(self.weights_dir) - if self.task != 'classification': + if (self.task != 'classification' and self.task != 'reading_order'): self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] @@ -233,6 +232,178 @@ class sbb_predict: index_class = np.argmax(label_p_pred[0]) print("Predicted Class: {}".format(classes_names[str(int(index_class))])) + elif self.task == 'reading_order': + img_height = self.config_params_model['input_height'] + img_width = self.config_params_model['input_width'] + + tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) + _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) + + img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + + for j in range(len(cy_main)): + 
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 + + co_text_all = co_text_paragraph + co_text_header + id_all_text = id_paragraph + id_header + + ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] + ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] + texts_corr_order_index_int = list(np.array(range(len(co_text_all)))) + + min_area = 0 + max_area = 1 + + co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + + labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8') + for i in range(len(co_text_all)): + img_label = np.zeros((y_len,x_len,3),dtype='uint8') + img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) + labels_con[:,:,i] = img_label[:,:,0] + + img3= np.copy(img_poly) + labels_con = resize_image(labels_con, img_height, img_width) + + img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width) + + img3= resize_image (img3, img_height, img_width) + img3 = img3.astype(np.uint16) + + inference_bs = 1#4 + + input_1= np.zeros( (inference_bs, img_height, img_width,3)) + + + starting_list_of_regions = [] + starting_list_of_regions.append( list(range(labels_con.shape[2])) ) + + index_update = 0 + index_selected = starting_list_of_regions[0] + + scalibility_num = 0 + while index_update>=0: + ij_list = starting_list_of_regions[index_update] + i = ij_list[0] + ij_list.pop(0) + + + pr_list = [] + post_list = [] + + batch_counter = 0 + tot_counter = 1 + + tot_iteration = len(ij_list) + full_bs_ite= tot_iteration//inference_bs + last_bs = tot_iteration % inference_bs + + jbatch_indexer =[] + for j in ij_list: + img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2) + img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2) + + + img2[:,:,0][img3[:,:,0]==5] = 2 + img2[:,:,0][img_header_and_sep[:,:]==1] = 3 
+ + + + img1[:,:,0][img3[:,:,0]==5] = 2 + img1[:,:,0][img_header_and_sep[:,:]==1] = 3 + + #input_1= np.zeros( (height1, width1,3)) + + + jbatch_indexer.append(j) + + input_1[batch_counter,:,:,0] = img1[:,:,0]/3. + input_1[batch_counter,:,:,2] = img2[:,:,0]/3. + input_1[batch_counter,:,:,1] = img3[:,:,0]/5. + #input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3)) + batch_counter = batch_counter+1 + + #input_1[:,:,0] = img1[:,:,0]/3. + #input_1[:,:,2] = img2[:,:,0]/3. + #input_1[:,:,1] = img3[:,:,0]/5. + + if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs): + y_pr = self.model.predict(input_1 , verbose=0) + scalibility_num = scalibility_num+1 + + if batch_counter==inference_bs: + iteration_batches = inference_bs + else: + iteration_batches = last_bs + for jb in range(iteration_batches): + if y_pr[jb][0]>=0.5: + post_list.append(jbatch_indexer[jb]) + else: + pr_list.append(jbatch_indexer[jb]) + + batch_counter = 0 + jbatch_indexer = [] + + tot_counter = tot_counter+1 + + starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions) + + index_sort = [i[0] for i in starting_list_of_regions ] + + + alltags=[elem.tag for elem in root_xml.iter()] + + + + link=alltags[0].split('}')[0]+'}' + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + page_element = root_xml.find(link+'Page') + + """ + ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + #print(page_element, 'page_element') + + #new_element = ET.Element('ReadingOrder') + + new_element_element = ET.Element('OrderedGroup') + new_element_element.set('id', "ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.Element('RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index_sort[index])) + + 
new_element_element.append(new_element_2) + + ro_subelement.append(new_element_element) + """ + ##ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + + ro_subelement = ET.Element('ReadingOrder') + + ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup') + ro_subelement2.set('id', "ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index_sort[index])) + + if link+'PrintSpace' in alltags: + page_element.insert(1, ro_subelement) + else: + page_element.insert(0, ro_subelement) + + #page_element[0].append(new_element) + #root_xml.append(new_element) + alltags=[elem.tag for elem in root_xml.iter()] + + ET.register_namespace("",name_space) + tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + #tree_xml.write('library2.xml') + else: if self.patches: #def textline_contours(img,input_width,input_height,n_classes,model): @@ -356,7 +527,7 @@ class sbb_predict: def run(self): res=self.predict() - if self.task == 'classification': + if (self.task == 'classification' or self.task == 'reading_order'): pass else: img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) @@ -397,15 +568,20 @@ class sbb_predict: "-gt", help="ground truth directory if you want to see the iou of prediction.", ) -def main(image, model, patches, save, ground_truth): +@click.option( + "--xml_file", + "-xml", + help="xml file with layout coordinates that reading order detection will be implemented on. 
The result will be written in the same xml file.", +) +def main(image, model, patches, save, ground_truth, xml_file): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] - if task != 'classification': + if (task != 'classification' and task != 'reading_order'): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file) x.run() if __name__=="__main__": From 06ed00619399fb93d48bd803f4bd66ba942d4d84 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 29 May 2024 11:18:35 +0200 Subject: [PATCH 060/123] reading order detection on xml with layout + result will be written in an output directory with the same file name --- train/gt_gen_utils.py | 74 +++++++++++++++++++++++++++++++++++++------ train/inference.py | 45 +++++++++++++++++++------- 2 files changed, 99 insertions(+), 20 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 0286ac7..8f72fb8 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -664,6 +664,58 @@ def read_xml(xml_file): for jj in root1.iter(link+'RegionRefIndexed'): index_tot_regions.append(jj.attrib['index']) tot_region_ref.append(jj.attrib['regionRef']) + + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): + co_printspace = [] + if link+'PrintSpace' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + elif link+'Border' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) + + for tag in region_tags_printspace: + if link+'PrintSpace' in alltags: + tag_endings_printspace = ['}PrintSpace','}printspace'] + elif link+'Border' in alltags: + tag_endings_printspace 
= ['}Border','}border'] + + if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_printspace.append(np.array(c_t_in)) + img_printspace = np.zeros( (y_len,x_len,3) ) + img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1)) + img_printspace = img_printspace.astype(np.uint8) + + imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + + bb_coord_printspace = [x, y, w, h] + + else: + bb_coord_printspace = None + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) co_text_paragraph=[] @@ -754,7 +806,7 @@ def read_xml(xml_file): c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='heading': - id_heading.append(nn.attrib['id']) + ##id_heading.append(nn.attrib['id']) c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -763,7 +815,7 @@ def read_xml(xml_file): c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) #print(c_t_in_paragraph) elif "type" in nn.attrib and nn.attrib['type']=='header': - 
id_header.append(nn.attrib['id']) + #id_header.append(nn.attrib['id']) c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -776,11 +828,11 @@ def read_xml(xml_file): ###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - id_marginalia.append(nn.attrib['id']) + #id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) else: - id_paragraph.append(nn.attrib['id']) + #id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -796,7 +848,7 @@ def read_xml(xml_file): sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='heading': - id_heading.append(nn.attrib['id']) + #id_heading.append(nn.attrib['id']) c_t_in_heading.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -806,7 +858,7 @@ def read_xml(xml_file): c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='header': - id_header.append(nn.attrib['id']) + #id_header.append(nn.attrib['id']) c_t_in_header.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -821,13 +873,13 @@ def read_xml(xml_file): ###sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - id_marginalia.append(nn.attrib['id']) + #id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 else: - id_paragraph.append(nn.attrib['id']) + #id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -838,11 +890,14 @@ def read_xml(xml_file): co_text_drop.append(np.array(c_t_in_drop)) if len(c_t_in_paragraph)>0: 
co_text_paragraph.append(np.array(c_t_in_paragraph)) + id_paragraph.append(nn.attrib['id']) if len(c_t_in_heading)>0: co_text_heading.append(np.array(c_t_in_heading)) + id_heading.append(nn.attrib['id']) if len(c_t_in_header)>0: co_text_header.append(np.array(c_t_in_header)) + id_header.append(nn.attrib['id']) if len(c_t_in_page_number)>0: co_text_page_number.append(np.array(c_t_in_page_number)) if len(c_t_in_catch)>0: @@ -853,6 +908,7 @@ def read_xml(xml_file): if len(c_t_in_marginalia)>0: co_text_marginalia.append(np.array(c_t_in_marginalia)) + id_marginalia.append(nn.attrib['id']) elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): @@ -1014,7 +1070,7 @@ def read_xml(xml_file): img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) - return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ + return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ tot_region_ref,x_len, y_len,index_tot_regions, img_poly diff --git a/train/inference.py b/train/inference.py index 73b4ed8..28445e8 100644 --- a/train/inference.py +++ b/train/inference.py @@ -16,6 +16,7 @@ import click import json from tensorflow.python.keras import backend as tensorflow_backend import xml.etree.ElementTree as ET +import matplotlib.pyplot as plt with warnings.catch_warnings(): @@ -27,7 +28,7 @@ Tool to load model and predict for given image. 
""" class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file, out): self.image=image self.patches=patches self.save=save @@ -36,6 +37,7 @@ class sbb_predict: self.task=task self.config_params_model=config_params_model self.xml_file = xml_file + self.out = out def resize_image(self,img_in,input_height,input_width): return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) @@ -236,16 +238,18 @@ class sbb_predict: img_height = self.config_params_model['input_height'] img_width = self.config_params_model['input_width'] - tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) + tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + for j in range(len(cy_main)): img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 co_text_all = co_text_paragraph + co_text_header id_all_text = id_paragraph + id_header + ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] @@ -253,8 +257,9 @@ class sbb_predict: min_area = 0 max_area = 1 + - co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + ##co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, 
min_area) labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8') for i in range(len(co_text_all)): @@ -262,6 +267,18 @@ class sbb_predict: img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) labels_con[:,:,i] = img_label[:,:,0] + if bb_coord_printspace: + #bb_coord_printspace[x,y,w,h,_,_] + x = bb_coord_printspace[0] + y = bb_coord_printspace[1] + w = bb_coord_printspace[2] + h = bb_coord_printspace[3] + labels_con = labels_con[y:y+h, x:x+w, :] + img_poly = img_poly[y:y+h, x:x+w, :] + img_header_and_sep = img_header_and_sep[y:y+h, x:x+w] + + + img3= np.copy(img_poly) labels_con = resize_image(labels_con, img_height, img_width) @@ -347,9 +364,11 @@ class sbb_predict: tot_counter = tot_counter+1 starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions) - + + index_sort = [i[0] for i in starting_list_of_regions ] + id_all_text = np.array(id_all_text)[index_sort] alltags=[elem.tag for elem in root_xml.iter()] @@ -389,19 +408,17 @@ class sbb_predict: for index, id_text in enumerate(id_all_text): new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed') new_element_2.set('regionRef', id_all_text[index]) - new_element_2.set('index', str(index_sort[index])) + new_element_2.set('index', str(index)) - if link+'PrintSpace' in alltags: + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): page_element.insert(1, ro_subelement) else: page_element.insert(0, ro_subelement) - #page_element[0].append(new_element) - #root_xml.append(new_element) alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) - tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + tree_xml.write(os.path.join(self.out, file_name+'.xml'),xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) #tree_xml.write('library2.xml') else: @@ -545,6 +562,12 @@ class 
sbb_predict: help="image filename", type=click.Path(exists=True, dir_okay=False), ) +@click.option( + "--out", + "-o", + help="output directory where xml with detected reading order will be written.", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--patches/--no-patches", "-p/-nop", @@ -573,7 +596,7 @@ class sbb_predict: "-xml", help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", ) -def main(image, model, patches, save, ground_truth, xml_file): +def main(image, model, patches, save, ground_truth, xml_file, out): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] @@ -581,7 +604,7 @@ def main(image, model, patches, save, ground_truth, xml_file): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out) x.run() if __name__=="__main__": From 09789619a8fe9589352f7bde6c0e7cb41a9ea087 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 29 May 2024 13:07:06 +0200 Subject: [PATCH 061/123] min_area size of regions considered for reading order detection passed as an argument for inference --- train/gt_gen_utils.py | 13 +++++++++++-- train/inference.py | 31 ++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 8f72fb8..d3dd7df 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -32,10 +32,16 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m jv = 0 for c in contours: + if len(np.shape(c)) == 3: + c = c[0] + elif len(np.shape(c)) == 2: + pass + #c = c[0] if 
len(c) < 3: # A polygon cannot have less than 3 points continue - polygon = geometry.Polygon([point[0] for point in c]) + c_e = [point for point in c] + polygon = geometry.Polygon(c_e) # area = cv2.contourArea(c) area = polygon.area # Check that polygon has area greater than minimal area @@ -49,7 +55,10 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar order_index_filtered = list() #jv = 0 for jv, c in enumerate(contours): - c = c[0] + if len(np.shape(c)) == 3: + c = c[0] + elif len(np.shape(c)) == 2: + pass if len(c) < 3: # A polygon cannot have less than 3 points continue c_e = [point for point in c] diff --git a/train/inference.py b/train/inference.py index 28445e8..c7a8b02 100644 --- a/train/inference.py +++ b/train/inference.py @@ -28,7 +28,7 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file, out): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file, out, min_area): self.image=image self.patches=patches self.save=save @@ -38,6 +38,10 @@ class sbb_predict: self.config_params_model=config_params_model self.xml_file = xml_file self.out = out + if min_area: + self.min_area = float(min_area) + else: + self.min_area = 0 def resize_image(self,img_in,input_height,input_width): return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) @@ -255,11 +259,18 @@ class sbb_predict: ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] texts_corr_order_index_int = list(np.array(range(len(co_text_all)))) - min_area = 0 - max_area = 1 + #print(texts_corr_order_index_int) - - ##co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + max_area = 1 + #print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all') + #co_text_all = 
filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area) + #print(co_text_all,'co_text_all') + co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area) + + #print(texts_corr_order_index_int) + + #co_text_all = [co_text_all[index] for index in texts_corr_order_index_int] + id_all_text = [id_all_text[index] for index in texts_corr_order_index_int] labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8') for i in range(len(co_text_all)): @@ -596,7 +607,13 @@ class sbb_predict: "-xml", help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", ) -def main(image, model, patches, save, ground_truth, xml_file, out): + +@click.option( + "--min_area", + "-min", + help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", +) +def main(image, model, patches, save, ground_truth, xml_file, out, min_area): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] @@ -604,7 +621,7 @@ def main(image, model, patches, save, ground_truth, xml_file, out): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out, min_area) x.run() if __name__=="__main__": From 47a16464518f32427d7ff609bbc572303c2ed148 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 30 May 2024 12:56:56 +0200 Subject: [PATCH 062/123] modifying xml parsing --- train/gt_gen_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) 
diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index d3dd7df..debaf15 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -122,7 +122,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ ## to do: add footnote to text regions for index in tqdm(range(len(gt_list))): #try: - tree1 = ET.parse(dir_in+'/'+gt_list[index]) + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' @@ -658,7 +658,7 @@ def find_new_features_of_contours(contours_main): return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin def read_xml(xml_file): file_name = Path(xml_file).stem - tree1 = ET.parse(xml_file) + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' From 3ef0dbdd4281bfe4cabd13765fc9723ea1e506c2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 30 May 2024 16:59:50 +0200 Subject: [PATCH 063/123] scaling and cropping of labels and org images --- train/custom_config_page2label.json | 5 +- train/generate_gt_for_training.py | 34 ++++++-- train/gt_gen_utils.py | 125 ++++++++++++++++++++++++++-- 3 files changed, 145 insertions(+), 19 deletions(-) diff --git a/train/custom_config_page2label.json b/train/custom_config_page2label.json index e4c02cb..9116ce3 100644 --- a/train/custom_config_page2label.json +++ b/train/custom_config_page2label.json @@ -1,9 +1,8 @@ { -"use_case": "layout", +"use_case": "textline", "textregions":{ "rest_as_paragraph": 1, "header":2 , "heading":2 , "marginalia":3 }, "imageregion":4, "separatorregion":5, "graphicregions" :{"rest_as_decoration":6}, -"artificial_class_on_boundry": ["paragraph"], -"artificial_class_label":7 +"columns_width":{"1":1000, "2":1300, "3":1600, "4":2000, "5":2300, "6":2500} } 
diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index cf2b2a6..752090c 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -14,10 +14,22 @@ def main(): help="directory of GT page-xml files", type=click.Path(exists=True, file_okay=False), ) +@click.option( + "--dir_images", + "-di", + help="directory of org images. If print space cropping or scaling is needed for labels it would be great to provide the original images to apply the same function on them. So if -ps is not set true or in config files no columns_width key is given this argumnet can be ignored. File stems in this directory should be the same as those in dir_xml.", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out_images", + "-doi", + help="directory where the output org images after undergoing a process (like print space cropping or scaling) will be written.", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--dir_out", "-do", - help="directory where ground truth images would be written", + help="directory where ground truth label images would be written", type=click.Path(exists=True, file_okay=False), ) @@ -33,8 +45,14 @@ def main(): "-to", help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 
2D image array is 3d but only information of one channel would be enough since all channels have the same values.", ) +@click.option( + "--printspace", + "-ps", + is_flag=True, + help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.", +) -def pagexml2label(dir_xml,dir_out,type_output,config): +def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): if config: with open(config) as f: config_params = json.load(f) @@ -42,7 +60,7 @@ def pagexml2label(dir_xml,dir_out,type_output,config): print("passed") config_params = None gt_list = get_content_of_dir(dir_xml) - get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params) + get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images) @main.command() @click.option( @@ -181,7 +199,7 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i for i in range(len(texts_corr_order_index_int)): for j in range(len(texts_corr_order_index_int)): if i!=j: - input_matrix = np.zeros((input_height,input_width,3)).astype(np.int8) + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) final_f_name = f_name+'_'+str(indexer+indexer_start) order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] if order_class_condition<0: @@ -189,13 +207,13 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i else: class_type = 0 - input_matrix[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) - input_matrix[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) - input_matrix[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) + input_multi_visual_modal[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) + 
input_multi_visual_modal[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) + input_multi_visual_modal[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) - cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_matrix) + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) indexer = indexer+1 diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index debaf15..d3e95e8 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -115,11 +115,15 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y img_boundary[:,:][boundary[:,:]==1] =1 return co_text_eroded, img_boundary -def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params): +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ Reading the page xml files and write the ground truth images into given output directory. 
""" ## to do: add footnote to text regions + + if dir_images: + ls_org_imgs = os.listdir(dir_images) + ls_org_imgs_stem = [item.split('.')[0] for item in ls_org_imgs] for index in tqdm(range(len(gt_list))): #try: tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) @@ -133,6 +137,72 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + comment_is_sub_element = False + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) + comment_is_sub_element = True + if not comment_is_sub_element: + num_col = None + + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) + + if printspace: + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] + + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) + + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + img_poly = 
img_poly.astype(np.uint8) + + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + + cnt = contours[np.argmax(cnt_size)] + + x, y, w, h = cv2.boundingRect(cnt) + bb_xywh = [x, y, w, h] + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): keys = list(config_params.keys()) if "artificial_class_label" in keys: @@ -186,7 +256,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ co_use_case.append(np.array(c_t_in)) - if "artificial_class_label" in keys: img_boundary = np.zeros((y_len, x_len)) erosion_rate = 1 @@ -205,12 +274,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + + if printspace and config_params['use_case']!='printspace': + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_poly = resize_image(img_poly, y_new, x_new) try: - cv2.imwrite(output_dir + '/' + gt_list[index].split('-')[1].split('.')[0] + '.png', - img_poly) + xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - cv2.imwrite(output_dir + '/' + gt_list[index].split('.')[0] + '.png', img_poly) + xml_file_stem = gt_list[index].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + if dir_images: + 
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace and config_params['use_case']!='printspace': + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) if config_file and config_params['use_case']=='layout': @@ -616,11 +705,31 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ + if printspace: + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] - try: - cv2.imwrite(output_dir+'/'+gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + if 'columns_width' in list(config_params.keys()) and num_col: + img_poly = resize_image(img_poly, y_new, x_new) + + try: + xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - cv2.imwrite(output_dir+'/'+gt_list[index].split('.')[0]+'.png',img_poly ) + xml_file_stem = gt_list[index].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + + if dir_images: + org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace: + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col: + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) From 13ebe71d1349d5802d9ff5aa1e79e95141185371 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 6 Jun 2024 14:38:29 +0200 Subject: [PATCH 064/123] replacement in a list done correctly --- 
train/gt_gen_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index d3e95e8..38e77e8 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -636,7 +636,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if 'textregions' in keys: if 'rest_as_paragraph' in types_text: - types_text[types_text=='rest_as_paragraph'] = 'paragraph' + types_text = ['paragraph'if ttind=='rest_as_paragraph' else ttind for ttind in types_text] for element_text in types_text: if element_text == 'paragraph': color_label = labels_rgb_color[ config_params['textregions']['rest_as_paragraph']] @@ -688,7 +688,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if 'textregions' in keys: if 'rest_as_paragraph' in types_text: - types_text[types_text=='rest_as_paragraph'] = 'paragraph' + types_text = ['paragraph'if ttind=='rest_as_paragraph' else ttind for ttind in types_text] for element_text in types_text: if element_text == 'paragraph': color_label = config_params['textregions']['rest_as_paragraph'] From 742e3c2aa28171cbeff8517cf49ab779d196ee23 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 6 Jun 2024 14:46:06 +0200 Subject: [PATCH 065/123] Update README.md --- train/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train/README.md b/train/README.md index 899c9a3..b9e70a8 100644 --- a/train/README.md +++ b/train/README.md @@ -73,3 +73,6 @@ The output folder should be an empty folder where the output model will be writt * weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` * data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". 
Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". * dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. + +#### Additional documentation +Please check the [wiki](https://github.com/qurator-spk/sbb_pixelwise_segmentation/wiki). From 5a5914e06c1185f24de378dc752892e699c0446b Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 6 Jun 2024 18:45:47 +0200 Subject: [PATCH 066/123] just defined textregion types can be extracted as label --- train/gt_gen_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 38e77e8..86eb0a1 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -325,6 +325,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ region_tags=np.unique([x for x in alltags if x.endswith('Region')]) co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + all_defined_textregion_types = list(co_text.keys()) co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} co_sep=[] co_img=[] @@ -359,7 +360,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ else: if "type" in nn.attrib: - c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + if nn.attrib['type'] in 
all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) break else: @@ -384,8 +386,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ else: if "type" in nn.attrib: - c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) - sumi+=1 + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 elif vv.tag!=link+'Point' and sumi>=1: From 4c376289e97890a55755e72198d20fde37dd1146 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 6 Jun 2024 18:55:22 +0200 Subject: [PATCH 067/123] just defined graphic region types can be extracted as label --- train/gt_gen_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 86eb0a1..c2360fc 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -327,6 +327,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} all_defined_textregion_types = list(co_text.keys()) co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + all_defined_graphic_types = list(co_graphic.keys()) co_sep=[] co_img=[] co_table=[] @@ -425,7 +426,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ else: if "type" in nn.attrib: - c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) break else: 
@@ -450,8 +452,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ else: if "type" in nn.attrib: - c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) - sumi+=1 + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 elif vv.tag!=link+'Point' and sumi>=1: break From cc91e4b12c42076f76bf3e8409c050ad80e9cf78 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 7 Jun 2024 16:24:31 +0200 Subject: [PATCH 068/123] updating train.py --- train/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train/train.py b/train/train.py index f338c78..e16745f 100644 --- a/train/train.py +++ b/train/train.py @@ -59,6 +59,8 @@ def config_params(): pretraining = False # Set to true to load pretrained weights of ResNet50 encoder. scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. + rotation = False # If true, a 90 degree rotation will be implemeneted. + rotation_not_90 = False # If true rotation based on provided angles with thetha will be implemeneted. scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. thetha = None # Rotate image by these angles for augmentation. 
From 1921e6754f7abbafb5f7f2731f2d29588bf4eac6 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 10 Jun 2024 22:15:30 +0200 Subject: [PATCH 069/123] updating train.py nontransformer backend --- train/models.py | 13 +++++++++---- train/train.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/train/models.py b/train/models.py index d852ac3..b8b0d27 100644 --- a/train/models.py +++ b/train/models.py @@ -30,8 +30,8 @@ class Patches(layers.Layer): self.patch_size = patch_size def call(self, images): - print(tf.shape(images)[1],'images') - print(self.patch_size,'self.patch_size') + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') batch_size = tf.shape(images)[0] patches = tf.image.extract_patches( images=images, @@ -41,7 +41,7 @@ class Patches(layers.Layer): padding="VALID", ) patch_dims = patches.shape[-1] - print(patches.shape,patch_dims,'patch_dims') + #print(patches.shape,patch_dims,'patch_dims') patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return patches def get_config(self): @@ -51,6 +51,7 @@ class Patches(layers.Layer): 'patch_size': self.patch_size, }) return config + class PatchEncoder(layers.Layer): def __init__(self, num_patches, projection_dim): @@ -408,7 +409,11 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu if pretraining: model = Model(inputs, x).load_weights(resnet50_Weights_path) - num_patches = x.shape[1]*x.shape[2] + #num_patches = x.shape[1]*x.shape[2] + + #patch_size_y = input_height / x.shape[1] + #patch_size_x = input_width / x.shape[2] + #patch_size = patch_size_x * patch_size_y patches = Patches(patch_size)(x) # Encode patches. 
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) diff --git a/train/train.py b/train/train.py index e16745f..84c9d3b 100644 --- a/train/train.py +++ b/train/train.py @@ -97,8 +97,6 @@ def run(_config, n_classes, n_epochs, input_height, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): if task == "segmentation" or task == "enhancement": - - num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -213,7 +211,15 @@ def run(_config, n_classes, n_epochs, input_height, index_start = 0 if backbone_type=='nontransformer': model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) - elif backbone_type=='nontransformer': + elif backbone_type=='transformer': + num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] + + if not (num_patches == (input_width / 32) * (input_height / 32)): + print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + sys.exit(1) + if not (transformer_patchsize == 1): + print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + sys.exit(1) model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
From 29da23da7663ade94f9dc158ba9cd04a39a6f114 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 11 Jun 2024 17:48:30 +0200 Subject: [PATCH 070/123] binarization as a separate task of segmentation --- train/train.py | 13 +++++++------ train/utils.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/train/train.py b/train/train.py index 84c9d3b..9e06a66 100644 --- a/train/train.py +++ b/train/train.py @@ -96,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): - if task == "segmentation" or task == "enhancement": + if task == "segmentation" or task == "enhancement" or task == "binarization": if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -194,16 +194,16 @@ def run(_config, n_classes, n_epochs, input_height, if continue_training: if backbone_type=='nontransformer': - if is_loss_soft_dice and task == "segmentation": + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss and task == "segmentation": + if weighted_loss and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True) elif backbone_type=='transformer': - if is_loss_soft_dice and task == "segmentation": + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) - if weighted_loss and task == 
"segmentation": + if weighted_loss and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) @@ -224,8 +224,9 @@ def run(_config, n_classes, n_epochs, input_height, #if you want to see the model structure just uncomment model summary. #model.summary() + - if task == "segmentation": + if (task == "segmentation" or task == "binarization"): if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=learning_rate), metrics=['accuracy']) diff --git a/train/utils.py b/train/utils.py index a2e8a9c..605d8d1 100644 --- a/train/utils.py +++ b/train/utils.py @@ -309,7 +309,7 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize img[i - c] = train_img # add to array - img[0], img[1], and so on. 
- if task == "segmentation": + if task == "segmentation" or task=="binarization": train_mask = cv2.imread(mask_folder + '/' + filename + '.png') train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, n_classes) @@ -569,7 +569,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): img_name = im.split('.')[0] - if task == "segmentation": + if task == "segmentation" or task == "binarization": dir_of_label_file = os.path.join(dir_seg, img_name + '.png') elif task=="enhancement": dir_of_label_file = os.path.join(dir_seg, im) From 95faf1a4c8bc25ffe6d89fa2d296fccf95479e18 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 12 Jun 2024 13:26:27 +0200 Subject: [PATCH 071/123] transformer patch size is dynamic now. --- train/config_params.json | 28 +++++++++++++----------- train/models.py | 47 ++++++++++++++++++++++++++++++++-------- train/train.py | 30 ++++++++++++++++++------- 3 files changed, 75 insertions(+), 30 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index 8a56de5..6b8b6ed 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,42 +1,44 @@ { - "backbone_type" : "nontransformer", - "task": "classification", + "backbone_type" : "transformer", + "task": "binarization", "n_classes" : 2, - "n_epochs" : 20, - "input_height" : 448, - "input_width" : 448, + "n_epochs" : 1, + "input_height" : 224, + "input_width" : 672, "weight_decay" : 1e-6, - "n_batch" : 6, + "n_batch" : 1, "learning_rate": 1e-4, - "f1_threshold_classification": 0.8, "patches" : true, "pretraining" : true, "augmentation" : false, "flip_aug" : false, "blur_aug" : false, "scaling" : true, + "degrading": false, + "brightening": false, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, "rotation": false, "rotation_not_90": false, - 
"transformer_num_patches_xy": [28, 28], - "transformer_patchsize": 1, + "transformer_num_patches_xy": [7, 7], + "transformer_patchsize_x": 3, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 192, "blur_k" : ["blur","guass","median"], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], "brightness" : [1.3, 1.5, 1.7, 2], "degrade_scales" : [0.2, 0.4], "flip_index" : [0, 1, -1], "thetha" : [10, -10], - "classification_classes_name" : {"0":"apple", "1":"orange"}, "continue_training": false, "index_start" : 0, "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "./train", - "dir_eval": "./eval", - "dir_output": "./output" + "dir_train": "/home/vahid/Documents/test/training_data_sample_binarization", + "dir_eval": "/home/vahid/Documents/test/eval", + "dir_output": "/home/vahid/Documents/test/out" } diff --git a/train/models.py b/train/models.py index b8b0d27..1abf304 100644 --- a/train/models.py +++ b/train/models.py @@ -6,25 +6,49 @@ from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 mlp_head_units = [2048, 1024] -projection_dim = 64 +#projection_dim = 64 transformer_layers = 8 num_heads = 4 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 -transformer_units = [ - projection_dim * 2, - projection_dim, -] # Size of the transformer layers def mlp(x, hidden_units, dropout_rate): for units in hidden_units: x = layers.Dense(units, activation=tf.nn.gelu)(x) x = layers.Dropout(dropout_rate)(x) return x - class Patches(layers.Layer): + def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') 
+ batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size_x': self.patch_size_x, + 'patch_size_y': self.patch_size_y, + }) + return config + +class Patches_old(layers.Layer): def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): super(Patches, self).__init__() self.patch_size = patch_size @@ -369,8 +393,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) + + transformer_units = [ + projection_dim * 2, + projection_dim, + ] # Size of the transformer layers IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -414,7 +443,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu #patch_size_y = input_height / x.shape[1] #patch_size_x = input_width / x.shape[2] #patch_size = patch_size_x * patch_size_y - patches = Patches(patch_size)(x) + patches = Patches(patch_size_x, patch_size_y)(x) # Encode patches. encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) @@ -434,7 +463,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu # Skip connection 2. 
encoded_patches = layers.Add()([x3, x2]) - encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64]) + encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )]) v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) diff --git a/train/train.py b/train/train.py index 9e06a66..bafcc9e 100644 --- a/train/train.py +++ b/train/train.py @@ -70,8 +70,10 @@ def config_params(): brightness = None # Brighten image for augmentation. flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. - transformer_patchsize = None # Patch size of vision transformer patches. + transformer_patchsize_x = None # Patch size of vision transformer patches. + transformer_patchsize_y = None transformer_num_patches_xy = None # Number of patches for vision transformer. + transformer_projection_dim = 64 # Transformer projection dimension index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. 
@@ -92,7 +94,7 @@ def run(_config, n_classes, n_epochs, input_height, brightening, binarization, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, - thetha, scaling_flip, continue_training, transformer_patchsize, + thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): @@ -212,15 +214,27 @@ def run(_config, n_classes, n_epochs, input_height, if backbone_type=='nontransformer': model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) elif backbone_type=='transformer': - num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] + num_patches_x = transformer_num_patches_xy[0] + num_patches_y = transformer_num_patches_xy[1] + num_patches = num_patches_x * num_patches_y - if not (num_patches == (input_width / 32) * (input_height / 32)): - print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + ##if not (num_patches == (input_width / 32) * (input_height / 32)): + ##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + ##sys.exit(1) + #if not (transformer_patchsize == 1): + #print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + #sys.exit(1) + if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . 
input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") sys.exit(1) - if not (transformer_patchsize == 1): - print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") sys.exit(1) - model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + + model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
#model.summary() From 22d7359db2b1660272a32dd2e43f69f67373883f Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 12 Jun 2024 17:39:57 +0200 Subject: [PATCH 072/123] Transformer+CNN structure is added to vision transformer type --- train/config_params.json | 16 +++-- train/models.py | 142 ++++++++++++++++++++++++++++++++++++--- train/train.py | 57 ++++++++++------ 3 files changed, 176 insertions(+), 39 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index 6b8b6ed..d72530e 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -2,9 +2,9 @@ "backbone_type" : "transformer", "task": "binarization", "n_classes" : 2, - "n_epochs" : 1, + "n_epochs" : 2, "input_height" : 224, - "input_width" : 672, + "input_width" : 224, "weight_decay" : 1e-6, "n_batch" : 1, "learning_rate": 1e-4, @@ -22,10 +22,14 @@ "scaling_flip" : false, "rotation": false, "rotation_not_90": false, - "transformer_num_patches_xy": [7, 7], - "transformer_patchsize_x": 3, - "transformer_patchsize_y": 1, - "transformer_projection_dim": 192, + "transformer_num_patches_xy": [56, 56], + "transformer_patchsize_x": 4, + "transformer_patchsize_y": 4, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": false, "blur_k" : ["blur","guass","median"], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], "brightness" : [1.3, 1.5, 1.7, 2], diff --git a/train/models.py b/train/models.py index 1abf304..8841bd3 100644 --- a/train/models.py +++ b/train/models.py @@ -5,10 +5,10 @@ from tensorflow.keras.layers import * from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 -mlp_head_units = [2048, 1024] -#projection_dim = 64 -transformer_layers = 8 -num_heads = 4 +##mlp_head_units = [512, 256]#[2048, 1024] +###projection_dim = 64 +##transformer_layers = 2#8 +##num_heads = 1#4 resnet50_Weights_path = 
'./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 @@ -36,7 +36,8 @@ class Patches(layers.Layer): rates=[1, 1, 1, 1], padding="VALID", ) - patch_dims = patches.shape[-1] + #patch_dims = patches.shape[-1] + patch_dims = tf.shape(patches)[-1] patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return patches def get_config(self): @@ -393,13 +394,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) - transformer_units = [ - projection_dim * 2, - projection_dim, - ] # Size of the transformer layers + #transformer_units = [ + #projection_dim * 2, + #projection_dim, + #] # Size of the transformer layers IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -459,7 +460,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec # Layer normalization 2. x3 = layers.LayerNormalization(epsilon=1e-6)(x2) # MLP. - x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) # Skip connection 2. 
encoded_patches = layers.Add()([x3, x2]) @@ -515,6 +516,125 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec return model +def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + inputs = layers.Input(shape=(input_height, input_width, 3)) + + ##transformer_units = [ + ##projection_dim * 2, + ##projection_dim, + ##] # Size of the transformer layers + IMAGE_ORDERING = 'channels_last' + bn_axis=3 + + patches = Patches(patch_size_x, patch_size_y)(inputs) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) + # Skip connection 2. 
+ encoded_patches = layers.Add()([x3, x2]) + + encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )]) + + encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches) + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', 
data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o ,f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, inputs],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', 
data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(inputs=inputs, outputs=o) + + return model + def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): include_top=True assert input_height%32 == 0 diff --git a/train/train.py b/train/train.py index bafcc9e..71f31f3 100644 --- a/train/train.py +++ b/train/train.py @@ -70,10 +70,14 @@ def config_params(): brightness = None # Brighten image for augmentation. flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. - transformer_patchsize_x = None # Patch size of vision transformer patches. - transformer_patchsize_y = None - transformer_num_patches_xy = None # Number of patches for vision transformer. - transformer_projection_dim = 64 # Transformer projection dimension + transformer_patchsize_x = None # Patch size of vision transformer patches in x direction. + transformer_patchsize_y = None # Patch size of vision transformer patches in y direction. + transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively. + transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64. + transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64] + transformer_layers = 8 # transformer layers. Default value is 8. + transformer_num_heads = 4 # Transformer number of heads. Default value is 4. + transformer_cnn_first = True # We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. 
Default value is true. index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. @@ -94,7 +98,9 @@ def run(_config, n_classes, n_epochs, input_height, brightening, binarization, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, - thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y, + thetha, scaling_flip, continue_training, transformer_projection_dim, + transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, + transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): @@ -218,26 +224,33 @@ def run(_config, n_classes, n_epochs, input_height, num_patches_y = transformer_num_patches_xy[1] num_patches = num_patches_x * num_patches_y - ##if not (num_patches == (input_width / 32) * (input_height / 32)): - ##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) - ##sys.exit(1) - #if not (transformer_patchsize == 1): - #print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) - #sys.exit(1) - if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): - print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . 
input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") - sys.exit(1) - if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): - print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") - sys.exit(1) - if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: - print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") - sys.exit(1) + if transformer_cnn_first: + if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. 
The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + - model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + else: + if (input_height != (num_patches_y * transformer_patchsize_y) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
- #model.summary() + model.summary() if (task == "segmentation" or task == "binarization"): From 66022cf771dafd0cafa0734b545e60fc44fa07af Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 12 Jun 2024 17:40:40 +0200 Subject: [PATCH 073/123] update config --- train/config_params.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index d72530e..a89cbb5 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -42,7 +42,7 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/home/vahid/Documents/test/training_data_sample_binarization", - "dir_eval": "/home/vahid/Documents/test/eval", - "dir_output": "/home/vahid/Documents/test/out" + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" } From b3cd01de3761ce251b9171aa8f48318d926594f5 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 21 Jun 2024 13:06:26 +0200 Subject: [PATCH 074/123] update reading order machine based --- train/generate_gt_for_training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 752090c..cfcc151 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -163,8 +163,7 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i #print('########################') xml_file = os.path.join(dir_xml,ind_xml ) f_name = ind_xml.split('.')[0] - file_name, id_paragraph, id_header,co_text_paragraph,\ - co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) + _, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) id_all_text = id_paragraph + id_header co_text_all = co_text_paragraph + co_text_header From fe69b9c4a8428cc6a957f2b40c5aa559dd25416b Mon Sep 17 
00:00:00 2001 From: vahidrezanezhad Date: Fri, 21 Jun 2024 23:42:25 +0200 Subject: [PATCH 075/123] update inference --- train/inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train/inference.py b/train/inference.py index c7a8b02..3fec9c2 100644 --- a/train/inference.py +++ b/train/inference.py @@ -557,6 +557,10 @@ class sbb_predict: res=self.predict() if (self.task == 'classification' or self.task == 'reading_order'): pass + elif self.task == 'enhancement': + if self.save: + print(self.save) + cv2.imwrite(self.save,res) else: img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) if self.save: From 9260d2962a0fbdcc30ae836d5e21af2122764aa7 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 9 Jul 2024 03:04:29 +0200 Subject: [PATCH 076/123] resolving typo --- train/gt_gen_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index c2360fc..c264f4c 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -304,8 +304,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if config_file and config_params['use_case']=='layout': keys = list(config_params.keys()) - if "artificial_class_on_boundry" in keys: - elements_with_artificial_class = list(config_params['artificial_class_on_boundry']) + + if "artificial_class_on_boundary" in keys: + elements_with_artificial_class = list(config_params['artificial_class_on_boundary']) artificial_class_rgb_color = (255,255,0) artificial_class_label = config_params['artificial_class_label'] #values = config_params.values() @@ -567,8 +568,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) - - if "artificial_class_on_boundry" in keys: + + if "artificial_class_on_boundary" in keys: img_boundary = np.zeros( (y_len,x_len) ) if "paragraph" in 
elements_with_artificial_class: erosion_rate = 2 @@ -655,7 +656,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) - if "artificial_class_on_boundry" in keys: + if "artificial_class_on_boundary" in keys: img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] @@ -706,7 +707,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ color_label = config_params['textregions'][element_text] img_poly=cv2.fillPoly(img, pts =co_text[element_text], color=color_label) - if "artificial_class_on_boundry" in keys: + if "artificial_class_on_boundary" in keys: img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label From 3bceec9c19158030acdb59f8f84c2d0d66382414 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 16 Jul 2024 18:29:27 +0200 Subject: [PATCH 077/123] printspace_as_class_in_layout is integrated. 
Printspace can be defined as a class for layout segmentation --- train/gt_gen_utils.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index c264f4c..1df7b2a 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -154,7 +154,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ x_new = columns_width_dict[str(num_col)] y_new = int ( x_new * (y_len / float(x_len)) ) - if printspace: + if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) co_use_case = [] @@ -279,6 +279,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if printspace and config_params['use_case']!='printspace': img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': img_poly = resize_image(img_poly, y_new, x_new) @@ -310,6 +311,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ artificial_class_rgb_color = (255,255,0) artificial_class_label = config_params['artificial_class_label'] #values = config_params.values() + + if "printspace_as_class_in_layout" in list(config_params.keys()): + printspace_class_rgb_color = (125,125,255) + printspace_class_label = config_params['printspace_as_class_in_layout'] if 'textregions' in keys: types_text_dict = config_params['textregions'] @@ -614,7 +619,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ - img = np.zeros( (y_len,x_len,3) ) + img = np.zeros( (y_len,x_len,3) ) if output_type == '3d': if 'graphicregions' in keys: @@ -661,6 +666,15 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly[:,:,1][img_boundary[:,:]==1] = 
artificial_class_rgb_color[1] img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + if "printspace_as_class_in_layout" in list(config_params.keys()): + printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) + printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1 + + img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0] + img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1] + img_poly[:,:,2][printspace_mask[:,:] == 0] = printspace_class_rgb_color[2] + @@ -709,6 +723,14 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_on_boundary" in keys: img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + + if "printspace_as_class_in_layout" in list(config_params.keys()): + printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) + printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1 + + img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label + img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label + img_poly[:,:,2][printspace_mask[:,:] == 0] = printspace_class_label From 453d0fbf9220122096fd4578695783faa35823b7 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 17 Jul 2024 17:14:20 +0200 Subject: [PATCH 078/123] adding degrading and brightness augmentation to no patches case training --- train/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/train/utils.py b/train/utils.py index 605d8d1..7a2274c 100644 --- a/train/utils.py +++ b/train/utils.py @@ -597,6 +597,14 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 + if brightening: + for factor in brightness: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + 
'.png', + (resize_image(do_brightening(dir_img + '/' +im, factor), input_height, input_width))) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 if binarization: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', @@ -606,6 +614,15 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 + if degrading: + for degrade_scale_ind in degrade_scales: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + (resize_image(do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind), input_height, input_width))) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + if patches: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, From 861f0b1ebd39d8d2c7d127a0d335f8a3ef17c6e2 Mon Sep 17 00:00:00 2001 From: b-vr103 Date: Wed, 17 Jul 2024 18:20:24 +0200 Subject: [PATCH 079/123] brightness augmentation modified --- train/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/train/utils.py b/train/utils.py index 7a2274c..891ee15 100644 --- a/train/utils.py +++ b/train/utils.py @@ -599,12 +599,15 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer += 1 if brightening: for factor in brightness: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - (resize_image(do_brightening(dir_img + '/' +im, factor), input_height, input_width))) + try: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + (resize_image(do_brightening(dir_img + '/' +im, factor), input_height, input_width))) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_of_label_file), 
input_height, input_width)) - indexer += 1 + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + except: + pass if binarization: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', From 840d7c2283d6b71e083c6f10bf3b2e4b8f2e9102 Mon Sep 17 00:00:00 2001 From: b-vr103 Date: Tue, 23 Jul 2024 11:29:05 +0200 Subject: [PATCH 080/123] increasing margin in the case of pixelwise inference --- train/inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train/inference.py b/train/inference.py index 3fec9c2..49bebf8 100644 --- a/train/inference.py +++ b/train/inference.py @@ -219,7 +219,7 @@ class sbb_predict: added_image = cv2.addWeighted(img,0.5,output,0.1,0) - return added_image + return added_image, output def predict(self): self.start_new_session_and_model() @@ -444,7 +444,7 @@ class sbb_predict: if img.shape[1] < self.img_width: img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST) - margin = int(0 * self.img_width) + margin = int(0.1 * self.img_width) width_mid = self.img_width - 2 * margin height_mid = self.img_height - 2 * margin img = img / float(255.0) @@ -562,9 +562,10 @@ class sbb_predict: print(self.save) cv2.imwrite(self.save,res) else: - img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) + img_seg_overlayed, only_prediction = self.visualize_model_output(res, self.img_org, self.task) if self.save: cv2.imwrite(self.save,img_seg_overlayed) + cv2.imwrite('./layout.png', only_prediction) if self.ground_truth: gt_img=cv2.imread(self.ground_truth) From 2c822dae4e49d970d26a7776e20f55f34144d79e Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 24 Jul 2024 16:52:05 +0200 Subject: [PATCH 081/123] erosion and dilation parameters are changed & separators are written in label images after artificial label --- train/gt_gen_utils.py | 40 
+++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 1df7b2a..253c44a 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -577,8 +577,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_on_boundary" in keys: img_boundary = np.zeros( (y_len,x_len) ) if "paragraph" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 + erosion_rate = 0#2 + dilation_rate = 3#4 co_text['paragraph'], img_boundary = update_region_contours(co_text['paragraph'], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "drop-capital" in elements_with_artificial_class: erosion_rate = 0 @@ -586,35 +586,35 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ co_text["drop-capital"], img_boundary = update_region_contours(co_text["drop-capital"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "catch-word" in elements_with_artificial_class: erosion_rate = 0 - dilation_rate = 4 + dilation_rate = 2#4 co_text["catch-word"], img_boundary = update_region_contours(co_text["catch-word"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "page-number" in elements_with_artificial_class: erosion_rate = 0 - dilation_rate = 4 + dilation_rate = 2#4 co_text["page-number"], img_boundary = update_region_contours(co_text["page-number"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "header" in elements_with_artificial_class: - erosion_rate = 1 - dilation_rate = 4 + erosion_rate = 0#1 + dilation_rate = 3#4 co_text["header"], img_boundary = update_region_contours(co_text["header"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "heading" in elements_with_artificial_class: - erosion_rate = 1 - dilation_rate = 4 + erosion_rate = 0#1 + dilation_rate = 3#4 co_text["heading"], img_boundary = update_region_contours(co_text["heading"], 
img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "signature-mark" in elements_with_artificial_class: erosion_rate = 1 dilation_rate = 4 co_text["signature-mark"], img_boundary = update_region_contours(co_text["signature-mark"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "marginalia" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 + erosion_rate = 0#2 + dilation_rate = 3#4 co_text["marginalia"], img_boundary = update_region_contours(co_text["marginalia"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "footnote" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 + erosion_rate = 0#2 + dilation_rate = 2#4 co_text["footnote"], img_boundary = update_region_contours(co_text["footnote"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "footnote-continued" in elements_with_artificial_class: - erosion_rate = 2 - dilation_rate = 4 + erosion_rate = 0#2 + dilation_rate = 2#4 co_text["footnote-continued"], img_boundary = update_region_contours(co_text["footnote-continued"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) @@ -639,8 +639,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if 'imageregion' in keys: img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']]) - if 'separatorregion' in keys: - img_poly=cv2.fillPoly(img, pts =co_sep, color=labels_rgb_color[ config_params['separatorregion']]) if 'tableregion' in keys: img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) if 'noiseregion' in keys: @@ -666,6 +664,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + if 'separatorregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_sep, color=labels_rgb_color[ 
config_params['separatorregion']]) + if "printspace_as_class_in_layout" in list(config_params.keys()): printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) @@ -697,9 +698,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if 'imageregion' in keys: color_label = config_params['imageregion'] img_poly=cv2.fillPoly(img, pts =co_img, color=(color_label,color_label,color_label)) - if 'separatorregion' in keys: - color_label = config_params['separatorregion'] - img_poly=cv2.fillPoly(img, pts =co_sep, color=(color_label,color_label,color_label)) if 'tableregion' in keys: color_label = config_params['tableregion'] img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label)) @@ -724,6 +722,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_on_boundary" in keys: img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + if 'separatorregion' in keys: + color_label = config_params['separatorregion'] + img_poly=cv2.fillPoly(img, pts =co_sep, color=(color_label,color_label,color_label)) + if "printspace_as_class_in_layout" in list(config_params.keys()): printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1])) printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1 From 6fb28d6ce8cab024595a8a787d92129fbbeaf3c3 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 1 Aug 2024 14:30:51 +0200 Subject: [PATCH 082/123] erosion rate changed --- train/gt_gen_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 253c44a..13010bf 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -577,36 +577,36 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_on_boundary" in keys: img_boundary = np.zeros( (y_len,x_len) ) if "paragraph" in 
elements_with_artificial_class: - erosion_rate = 0#2 - dilation_rate = 3#4 + erosion_rate = 2 + dilation_rate = 4 co_text['paragraph'], img_boundary = update_region_contours(co_text['paragraph'], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "drop-capital" in elements_with_artificial_class: - erosion_rate = 0 - dilation_rate = 4 + erosion_rate = 1 + dilation_rate = 3 co_text["drop-capital"], img_boundary = update_region_contours(co_text["drop-capital"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "catch-word" in elements_with_artificial_class: erosion_rate = 0 - dilation_rate = 2#4 + dilation_rate = 3#4 co_text["catch-word"], img_boundary = update_region_contours(co_text["catch-word"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "page-number" in elements_with_artificial_class: erosion_rate = 0 - dilation_rate = 2#4 + dilation_rate = 3#4 co_text["page-number"], img_boundary = update_region_contours(co_text["page-number"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "header" in elements_with_artificial_class: - erosion_rate = 0#1 - dilation_rate = 3#4 + erosion_rate = 1 + dilation_rate = 4 co_text["header"], img_boundary = update_region_contours(co_text["header"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "heading" in elements_with_artificial_class: - erosion_rate = 0#1 - dilation_rate = 3#4 + erosion_rate = 1 + dilation_rate = 4 co_text["heading"], img_boundary = update_region_contours(co_text["heading"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "signature-mark" in elements_with_artificial_class: erosion_rate = 1 dilation_rate = 4 co_text["signature-mark"], img_boundary = update_region_contours(co_text["signature-mark"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "marginalia" in elements_with_artificial_class: - erosion_rate = 0#2 - dilation_rate = 3#4 + erosion_rate = 2 + dilation_rate = 4 co_text["marginalia"], img_boundary = 
update_region_contours(co_text["marginalia"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) if "footnote" in elements_with_artificial_class: erosion_rate = 0#2 From 2d83b8faad8e6e0983529cda221eb17ebb0048f4 Mon Sep 17 00:00:00 2001 From: Clemens Neudecker <952378+cneud@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:35:06 +0200 Subject: [PATCH 083/123] add documentation from wiki as markdown file to the codebase --- train/train.md | 576 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 576 insertions(+) create mode 100644 train/train.md diff --git a/train/train.md b/train/train.md new file mode 100644 index 0000000..553522b --- /dev/null +++ b/train/train.md @@ -0,0 +1,576 @@ +# Documentation for Training Models + +This repository assists users in preparing training datasets, training models, and performing inference with trained models. We cover various use cases including pixel-wise segmentation, image classification, image enhancement, and machine-based reading order. For each use case, we provide guidance on how to generate the corresponding training dataset. +All these use cases are now utilized in the Eynollah workflow. +As mentioned, the following three tasks can be accomplished using this repository: + +* Generate training dataset +* Train a model +* Inference with the trained model + +## Generate training dataset +The script generate_gt_for_training.py is used for generating training datasets. As the results of the following command demonstrate, the dataset generator provides three different commands: + +`python generate_gt_for_training.py --help` + + +These three commands are: + +* image-enhancement +* machine-based-reading-order +* pagexml2label + + +### image-enhancement + +Generating a training dataset for image enhancement is quite straightforward. All that is needed is a set of high-resolution images. 
The training dataset can then be generated using the following command: + +`python generate_gt_for_training.py image-enhancement -dis "dir of high resolution images" -dois "dir where degraded images will be written" -dols "dir where the corresponding high resolution image will be written as label" -scs "degrading scales json file"` + +The scales JSON file is a dictionary with a key named 'scales' and values representing scales smaller than 1. Images are downscaled based on these scales and then upscaled again to their original size. This process causes the images to lose resolution at different scales. The degraded images are used as input images, and the original high-resolution images serve as labels. The enhancement model can be trained with this generated dataset. The scales JSON file looks like this: + +```yaml +{ + "scales": [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] +} +``` + +### machine-based-reading-order + +For machine-based reading order, we aim to determine the reading priority between two sets of text regions. The model's input is a three-channel image: the first and last channels contain information about each of the two text regions, while the middle channel encodes prominent layout elements necessary for reading order, such as separators and headers. To generate the training dataset, our script requires a page XML file that specifies the image layout with the correct reading order. + +For output images, it is necessary to specify the width and height. Additionally, a minimum text region size can be set to filter out regions smaller than this minimum size. This minimum size is defined as the ratio of the text region area to the image area, with a default value of zero. 
To run the dataset generator, use the following command: + + +`python generate_gt_for_training.py machine-based-reading-order -dx "dir of GT xml files" -domi "dir where output images will be written" -docl "dir where the labels will be written" -ih "height" -iw "width" -min "min area ratio"` + +### pagexml2label + +pagexml2label is designed to generate labels from GT page XML files for various pixel-wise segmentation use cases, including 'layout,' 'textline,' 'printspace,' 'glyph,' and 'word' segmentation. +To train a pixel-wise segmentation model, we require images along with their corresponding labels. Our training script expects a PNG image where each pixel corresponds to a label, represented by an integer. The background is always labeled as zero, while other elements are assigned different integers. For instance, if we have ground truth data with four elements including the background, the classes would be labeled as 0, 1, 2, and 3 respectively. + +In binary segmentation scenarios such as textline or page extraction, the background is encoded as 0, and the desired element is automatically encoded as 1 in the PNG label. + +To specify the desired use case and the elements to be extracted in the PNG labels, a custom JSON file can be passed. 
For example, in the case of 'textline' detection, the JSON file would resemble this: + +```yaml +{ +"use_case": "textline" +} +``` + +In the case of layout segmentation a possible custom config json file can be like this: + +```yaml +{ +"use_case": "layout", +"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3}, +"imageregion":4, +"separatorregion":5, +"graphicregions" :{"rest_as_decoration":6 ,"stamp":7} +} +``` + +A possible custom config json file for layout segmentation where the "printspace" is wished to be a class: + +```yaml +{ +"use_case": "layout", +"textregions":{"rest_as_paragraph":1 , "drop-capital": 1, "header":2, "heading":2, "marginalia":3}, +"imageregion":4, +"separatorregion":5, +"graphicregions" :{"rest_as_decoration":6 ,"stamp":7} +"printspace_as_class_in_layout" : 8 +} +``` +For the layout use case, it is beneficial to first understand the structure of the page XML file and its elements. In a given image, the annotations of elements are recorded in a page XML file, including their contours and classes. For an image document, the known regions are 'textregion', 'separatorregion', 'imageregion', 'graphicregion', 'noiseregion', and 'tableregion'. + +Text regions and graphic regions also have their own specific types. The known types for us for text regions are 'paragraph', 'header', 'heading', 'marginalia', 'drop-capital', 'footnote', 'footnote-continued', 'signature-mark', 'page-number', and 'catch-word'. The known types for graphic regions are 'handwritten-annotation', 'decoration', 'stamp', and 'signature'. +Since we don't know all types of text and graphic regions, unknown cases can arise. To handle these, we have defined two additional types: "rest_as_paragraph" and "rest_as_decoration" to ensure that no unknown types are missed. This way, users can extract all known types from the labels and be confident that no unknown types are overlooked. 
+ +In the custom JSON file shown above, "header" and "heading" are extracted as the same class, while "marginalia" is shown as a different class. All other text region types, including "drop-capital," are grouped into the same class. For the graphic region, "stamp" has its own class, while all other types are classified together. "Image region" and "separator region" are also present in the label. However, other regions like "noise region" and "table region" will not be included in the label PNG file, even if they have information in the page XML files, as we chose not to include them. + +`python generate_gt_for_training.py pagexml2label -dx "dir of GT xml files" -do "dir where output label png files will be written" -cfg "custom config json file" -to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" "` + +We have also defined an artificial class that can be added to the boundary of text region types or text lines. This key is called "artificial_class_on_boundary." If users want to apply this to certain text regions in the layout use case, the example JSON config file should look like this: + +```yaml +{ + "use_case": "layout", + "textregions": { + "paragraph": 1, + "drop-capital": 1, + "header": 2, + "heading": 2, + "marginalia": 3 + }, + "imageregion": 4, + "separatorregion": 5, + "graphicregions": { + "rest_as_decoration": 6 + }, + "artificial_class_on_boundary": ["paragraph", "header", "heading", "marginalia"], + "artificial_class_label": 7 +} +``` + +This implies that the artificial class label, denoted by 7, will be present on PNG files and will only be added to the elements labeled as "paragraph," "header," "heading," and "marginalia." + +For "textline," "word," and "glyph," the artificial class on the boundaries will be activated only if the "artificial_class_label" key is specified in the config file. Its value should be set as 2 since these elements represent binary cases. 
For example, if the background and textline are denoted as 0 and 1 respectively, then the artificial class should be assigned the value 2. The example JSON config file should look like this for "textline" use case: + +```yaml +{ + "use_case": "textline", + "artificial_class_label": 2 +} +``` + +If the coordinates of "PrintSpace" or "Border" are present in the page XML ground truth files, and the user wishes to crop only the print space area, this can be achieved by activating the "-ps" argument. However, it should be noted that in this scenario, since cropping will be applied to the label files, the directory of the original images must be provided to ensure that they are cropped in sync with the labels. This ensures that the correct images and labels required for training are obtained. The command should resemble the following: + +`python generate_gt_for_training.py pagexml2label -dx "dir of GT xml files" -do "dir where output label png files will be written" -cfg "custom config json file" -to "output type which has 2d and 3d. 2d is used for training and 3d is just to visualise the labels" -ps -di "dir where the org images are located" -doi "dir where the cropped output images will be written" ` + +## Train a model +### classification + +For the classification use case, we haven't provided a ground truth generator, as it's unnecessary. For classification, all we require is a training directory with subdirectories, each containing images of its respective classes. We need separate directories for training and evaluation, and the class names (subdirectories) must be consistent across both directories. Additionally, the class names should be specified in the config JSON file, as shown in the following example. 
If, for instance, we aim to classify "apple" and "orange," with a total of 2 classes, the "classification_classes_name" key in the config file should appear as follows: + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "classification", + "n_classes" : 2, + "n_epochs" : 10, + "input_height" : 448, + "input_width" : 448, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "f1_threshold_classification": 0.8, + "pretraining" : true, + "classification_classes_name" : {"0":"apple", "1":"orange"}, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +The "dir_train" should be like this: + +``` +. +└── train # train directory + ├── apple # directory of images for apple class + └── orange # directory of images for orange class +``` + +And the "dir_eval" the same structure as train directory: + +``` +. +└── eval # evaluation directory + ├── apple # directory of images for apple class + └── orange # directory of images for orange class + +``` + +The classification model can be trained using the following command line: + +`python train.py with config_classification.json` + + +As evident in the example JSON file above, for classification, we utilize a "f1_threshold_classification" parameter. This parameter is employed to gather all models with an evaluation f1 score surpassing this threshold. Subsequently, an ensemble of these model weights is executed, and a model is saved in the output directory as "model_ens_avg". Additionally, the weight of the best model based on the evaluation f1 score is saved as "model_best". 
+ +### reading order +An example config json file for machine based reading order should be like this: + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "reading_order", + "n_classes" : 1, + "n_epochs" : 5, + "input_height" : 672, + "input_width" : 448, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "pretraining" : true, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +The "dir_train" should be like this: + +``` +. +└── train # train directory + ├── images # directory of images + └── labels # directory of labels +``` + +And the "dir_eval" the same structure as train directory: + +``` +. +└── eval # evaluation directory + ├── images # directory of images + └── labels # directory of labels +``` + +The classification model can be trained like the classification case command line. + +### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement + +#### Parameter configuration for segmentation or enhancement usecases + +The following parameter configuration can be applied to all segmentation use cases and enhancements. The augmentation, its sub-parameters, and continued training are defined only for segmentation use cases and enhancements, not for classification and machine-based reading order, as you can see in their example config files. + +* backbone_type: For segmentation tasks (such as text line, binarization, and layout detection) and enhancement, we offer two backbone options: a "nontransformer" and a "transformer" backbone. For the "transformer" backbone, we first apply a CNN followed by a transformer. In contrast, the "nontransformer" backbone utilizes only a CNN ResNet-50. +* task : The task parameter can have values such as "segmentation", "enhancement", "classification", and "reading_order". +* patches: If you want to break input images into smaller patches (input size of the model) you need to set this parameter to ``true``. 
In the case that the model should see the image once, like page extraction, patches should be set to ``false``. +* n_batch: Number of batches at each iteration. +* n_classes: Number of classes. In the case of binary classification this should be 2. In the case of reading_order it should set to 1. And for the case of layout detection just the unique number of classes should be given. +* n_epochs: Number of epochs. +* input_height: This indicates the height of model's input. +* input_width: This indicates the width of model's input. +* weight_decay: Weight decay of l2 regularization of model layers. +* pretraining: Set to ``true`` to load pretrained weights of ResNet50 encoder. The downloaded weights should be saved in a folder named "pretrained_model" in the same directory of "train.py" script. +* augmentation: If you want to apply any kind of augmentation this parameter should first set to ``true``. +* flip_aug: If ``true``, different types of filp will be applied on image. Type of flips is given with "flip_index" parameter. +* blur_aug: If ``true``, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" parameter. +* scaling: If ``true``, scaling will be applied on image. Scale of scaling is given with "scales" parameter. +* degrading: If ``true``, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" parameter. +* brightening: If ``true``, brightening will be applied to the image. The amount of brightening is defined with "brightness" parameter. +* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" parameter. +* rotation: If ``true``, 90 degree rotation will be applied on image. +* binarization: If ``true``,Otsu thresholding will be applied to augment the input data with binarized images. +* scaling_bluring: If ``true``, combination of scaling and blurring will be applied on image. 
+* scaling_binarization: If ``true``, combination of scaling and binarization will be applied on image. +* scaling_flip: If ``true``, combination of scaling and flip will be applied on image. +* flip_index: Type of flips. +* blur_k: Type of blurrings. +* scales: Scales of scaling. +* brightness: The amount of brightenings. +* thetha: Rotation angles. +* degrade_scales: The amount of degradings. +* continue_training: If ``true``, it means that you have already trained a model and you would like to continue the training. So it is needed to provide the dir of trained model with "dir_of_start_model" and index for naming the models. For example if you have already trained for 3 epochs then your last index is 2 and if you want to continue from model_1.h5, you can set ``index_start`` to 3 to start naming model with index 3. +* weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false`` +* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output". +* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories. +* index_start: Starting index for saved models in the case that "continue_training" is ``true``. 
+* dir_of_start_model: Directory containing pretrained model to continue training the model in the case that "continue_training" is ``true``. +* transformer_num_patches_xy: Number of patches for vision transformer in x and y direction respectively. +* transformer_patchsize_x: Patch size of vision transformer patches in x direction. +* transformer_patchsize_y: Patch size of vision transformer patches in y direction. +* transformer_projection_dim: Transformer projection dimension. Default value is 64. +* transformer_mlp_head_units: Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]. +* transformer_layers: transformer layers. Default value is 8. +* transformer_num_heads: Transformer number of heads. Default value is 4. +* transformer_cnn_first: We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true. + +In the case of segmentation and enhancement the train and evaluation directory should be as following. + +The "dir_train" should be like this: + +``` +. +└── train # train directory + ├── images # directory of images + └── labels # directory of labels +``` + +And the "dir_eval" the same structure as train directory: + +``` +. 
+└── eval # evaluation directory + ├── images # directory of images + └── labels # directory of labels +``` + +After configuring the JSON file for segmentation or enhancement, training can be initiated by running the following command, similar to the process for classification and reading order: + +`python train.py with config_classification.json` + +#### Binarization + +An example config json file for binarization can be like this: + +```yaml +{ + "backbone_type" : "transformer", + "task": "binarization", + "n_classes" : 2, + "n_epochs" : 4, + "input_height" : 224, + "input_width" : 672, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "transformer_num_patches_xy": [7, 7], + "transformer_patchsize_x": 3, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 192, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 8, + "transformer_num_heads": 4, + "transformer_cnn_first": true, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +#### Textline + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 2, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + 
"patches" : true, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +#### Enhancement + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "enhancement", + "n_classes" : 3, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +It's important to mention that the value of n_classes for enhancement should be 3, as the model's output is a 3-channel image. 
+ +#### Page extraction + +```yaml +{ + "backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 2, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : false, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` + +For page segmentation (or print space or border segmentation), the model needs to view the input image in its entirety, hence the patches parameter should be set to false. 
+ +#### layout segmentation + +An example config json file for layout segmentation with 5 classes (including background) can be like this: + +```yaml +{ + "backbone_type" : "transformer", + "task": "segmentation", + "n_classes" : 5, + "n_epochs" : 4, + "input_height" : 448, + "input_width" : 224, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-4, + "patches" : true, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": false, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "transformer_num_patches_xy": [7, 14], + "transformer_patchsize_x": 1, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 8, + "transformer_num_heads": 4, + "transformer_cnn_first": true, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "./train", + "dir_eval": "./eval", + "dir_output": "./output" +} +``` +## Inference with the trained model +### classification + +For conducting inference with a trained model, you simply need to execute the following command line, specifying the directory of the model and the image on which to perform inference: + + +`python inference.py -m "model dir" -i "image" ` + +This will straightforwardly return the class of the image. + +### machine based reading order + + +To infer the reading order using an reading order model, we need a page XML file containing layout information but without the reading order. 
We simply need to provide the model directory, the XML file, and the output directory. The new XML file with the added reading order will be written to the output directory with the same name. We need to run: + +`python inference.py -m "model dir" -xml "page xml file" -o "output dir to write new xml with reading order" ` + + +### Segmentation (Textline, Binarization, Page extraction and layout) and enhancement + +For conducting inference with a trained model for segmentation and enhancement you need to run the following command line: + + +`python inference.py -m "model dir" -i "image" -p -s "output image" ` + + +Note that in the case of page extraction the -p flag is not needed. + +For segmentation or binarization tasks, if a ground truth (GT) label is available, the IOU evaluation metric can be calculated for the output. To do this, you need to provide the GT label using the argument -gt. + + + From 3b90347a94521f6ed935ab1a94b39fe9504442ce Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 9 Aug 2024 12:46:18 +0200 Subject: [PATCH 084/123] save only layout output. 
different from overlayed layout on image --- train/inference.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/train/inference.py b/train/inference.py index 49bebf8..6054b01 100644 --- a/train/inference.py +++ b/train/inference.py @@ -32,6 +32,7 @@ class sbb_predict: self.image=image self.patches=patches self.save=save + self.save_layout=save_layout self.model_dir=model self.ground_truth=ground_truth self.task=task @@ -181,6 +182,7 @@ class sbb_predict: prediction = prediction * -1 prediction = prediction + 1 added_image = prediction * 255 + layout_only = None else: unique_classes = np.unique(prediction[:,:,0]) rgb_colors = {'0' : [255, 255, 255], @@ -200,26 +202,26 @@ class sbb_predict: '14' : [255, 125, 125], '15' : [255, 0, 255]} - output = np.zeros(prediction.shape) + layout_only = np.zeros(prediction.shape) for unq_class in unique_classes: rgb_class_unique = rgb_colors[str(int(unq_class))] - output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] - output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] - output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] - img = self.resize_image(img, output.shape[0], output.shape[1]) + img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1]) - output = output.astype(np.int32) + layout_only = layout_only.astype(np.int32) img = img.astype(np.int32) - added_image = cv2.addWeighted(img,0.5,output,0.1,0) + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) - return added_image, output + return added_image, layout_only def predict(self): self.start_new_session_and_model() @@ -559,13 +561,12 @@ class sbb_predict: pass elif self.task == 'enhancement': if self.save: - print(self.save) 
cv2.imwrite(self.save,res) else: - img_seg_overlayed, only_prediction = self.visualize_model_output(res, self.img_org, self.task) + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: cv2.imwrite(self.save,img_seg_overlayed) - cv2.imwrite('./layout.png', only_prediction) + cv2.imwrite(self.save_layout, only_layout) if self.ground_truth: gt_img=cv2.imread(self.ground_truth) @@ -595,6 +596,11 @@ class sbb_predict: "-s", help="save prediction as a png file in current folder.", ) +@click.option( + "--save_layout", + "-sl", + help="save layout prediction only as a png file in current folder.", +) @click.option( "--model", "-m", @@ -618,7 +624,7 @@ class sbb_predict: "-min", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", ) -def main(image, model, patches, save, ground_truth, xml_file, out, min_area): +def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] @@ -626,7 +632,7 @@ def main(image, model, patches, save, ground_truth, xml_file, out, min_area): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out, min_area) + x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) x.run() if __name__=="__main__": From bf5837bf6e4c44add1d401a9912fd1bd599df780 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 9 Aug 2024 13:20:09 +0200 Subject: [PATCH 085/123] update --- train/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/inference.py 
b/train/inference.py index 6054b01..8d0a572 100644 --- a/train/inference.py +++ b/train/inference.py @@ -28,7 +28,7 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file, out, min_area): + def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): self.image=image self.patches=patches self.save=save From 5e1821a7419bc20ff760eafccfb940b0c4938eb5 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 21 Aug 2024 00:48:30 +0200 Subject: [PATCH 086/123] augmentation function for red textlines, rgb background and scaling for no patch case --- train/utils.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/train/utils.py b/train/utils.py index 891ee15..2278849 100644 --- a/train/utils.py +++ b/train/utils.py @@ -12,6 +12,76 @@ from tensorflow.keras.utils import to_categorical from PIL import Image, ImageEnhance +def return_shuffled_channels(img, channels_order): + """ + channels order in ordinary case is like this [0, 1, 2]. In the case of shuffling the order should be provided. 
+ """ + img_sh = np.copy(img) + + img_sh[:,:,0]= img[:,:,channels_order[0]] + img_sh[:,:,1]= img[:,:,channels_order[1]] + img_sh[:,:,2]= img[:,:,channels_order[2]] + return img_sh + +def return_binary_image_with_red_textlines(img_bin): + img_red = np.copy(img_bin) + + img_red[:,:,0][img_bin[:,:,0] == 0] = 255 + return img_red + +def return_binary_image_with_given_rgb_background(img_bin, img_rgb_background): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_bin) + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + return img_final + +def return_binary_image_with_given_rgb_background_red_textlines(img_bin, img_rgb_background, img_color): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_color) + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + return img_final + +def scale_image_for_no_patch(img, label, scale): + h_n = int(img.shape[0]*scale) + w_n = int(img.shape[1]*scale) + + channel0_avg = int( np.mean(img[:,:,0]) ) + channel1_avg = int( np.mean(img[:,:,1]) ) + channel2_avg = int( np.mean(img[:,:,2]) ) + + h_diff = img.shape[0] - h_n + w_diff = img.shape[1] - w_n + + h_start = int(h_diff / 2.) + w_start = int(w_diff / 2.) 
+ + img_res = resize_image(img, h_n, w_n) + label_res = resize_image(label, h_n, w_n) + + img_scaled_padded = np.copy(img) + + label_scaled_padded = np.zeros(label.shape) + + img_scaled_padded[:,:,0] = channel0_avg + img_scaled_padded[:,:,1] = channel1_avg + img_scaled_padded[:,:,2] = channel2_avg + + img_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = img_res[:,:,:] + label_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = label_res[:,:,:] + + return img_scaled_padded, label_scaled_padded + + def return_number_of_total_training_data(path_classes): sub_classes = os.listdir(path_classes) n_tot = 0 From 445c45cb87935b73099d1753957c4c6c6eac32f2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 21 Aug 2024 16:17:59 +0200 Subject: [PATCH 087/123] updating augmentations --- train/train.py | 8 +++++--- train/utils.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/train/train.py b/train/train.py index 71f31f3..fa08a98 100644 --- a/train/train.py +++ b/train/train.py @@ -53,6 +53,7 @@ def config_params(): degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. + rgb_background = False dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_output = None # Directory where the output model will be saved. 
@@ -95,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, blur_aug, padding_white, padding_black, scaling, degrading, - brightening, binarization, blur_k, scales, degrade_scales, + brightening, binarization, rgb_background, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_projection_dim, @@ -108,6 +109,7 @@ def run(_config, n_classes, n_epochs, input_height, if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') @@ -161,7 +163,7 @@ def run(_config, n_classes, n_epochs, input_height, # writing patches into a sub-folder in order to be flowed from directory. 
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, + blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, @@ -169,7 +171,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) diff --git a/train/utils.py b/train/utils.py index 2278849..cf7a65c 100644 --- a/train/utils.py +++ b/train/utils.py @@ -695,6 +695,47 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 + + if rotation_not_90: + for thetha_i in thetha: + img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), thetha_i) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_max_rotated, input_height, input_width)) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_max_rotated, input_height, input_width)) + indexer += 1 + + if channels_shuffling: + for shuffle_index in 
shuffle_indexes: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + (resize_image(return_shuffled_channels(cv2.imread(dir_img + '/' + im), shuffle_index), input_height, input_width))) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + + if scaling: + for sc_ind in scales: + img_scaled, label_scaled = scale_image_for_no_patch(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), sc_ind) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_scaled, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) + indexer += 1 + + if rgb_color_background: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + + if patches: From aeb2ee4e3ef404b0fef2414462b9e51e9036bc18 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 21 Aug 2024 19:33:23 +0200 Subject: [PATCH 088/123] scaling, channels shuffling, rgb background and red content added to no patch augmentation --- train/config_params.json | 30 +++++++++++++++++++----------- train/train.py | 32 ++++++++++++++++++++++---------- train/utils.py | 32 +++++++++++++++++++++++++++----- 3 files changed, 68 insertions(+), 26 
deletions(-) diff --git a/train/config_params.json b/train/config_params.json index a89cbb5..e5f652d 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -1,19 +1,22 @@ { "backbone_type" : "transformer", - "task": "binarization", + "task": "segmentation", "n_classes" : 2, - "n_epochs" : 2, - "input_height" : 224, - "input_width" : 224, + "n_epochs" : 0, + "input_height" : 448, + "input_width" : 448, "weight_decay" : 1e-6, "n_batch" : 1, "learning_rate": 1e-4, - "patches" : true, + "patches" : false, "pretraining" : true, - "augmentation" : false, + "augmentation" : true, "flip_aug" : false, "blur_aug" : false, "scaling" : true, + "adding_rgb_background": true, + "add_red_textlines": true, + "channels_shuffling": true, "degrading": false, "brightening": false, "binarization" : false, @@ -31,18 +34,23 @@ "transformer_num_heads": 1, "transformer_cnn_first": false, "blur_k" : ["blur","guass","median"], - "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "scales" : [0.6, 0.7, 0.8, 0.9], "brightness" : [1.3, 1.5, 1.7, 2], "degrade_scales" : [0.2, 0.4], "flip_index" : [0, 1, -1], - "thetha" : [10, -10], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [5, -5], + "number_of_backgrounds_per_image": 2, "continue_training": false, "index_start" : 0, "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "./train", - "dir_eval": "./eval", - "dir_output": "./output" + "dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new", + "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", + "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new", + "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", + "dir_img_bin": 
"/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" + } diff --git a/train/train.py b/train/train.py index fa08a98..5dfad07 100644 --- a/train/train.py +++ b/train/train.py @@ -53,7 +53,9 @@ def config_params(): degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. - rgb_background = False + adding_rgb_background = False + add_red_textlines = False + channels_shuffling = False dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_output = None # Directory where the output model will be saved. @@ -65,6 +67,7 @@ def config_params(): scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. thetha = None # Rotate image by these angles for augmentation. + shuffle_indexes = None blur_k = None # Blur image for augmentation. scales = None # Scale patches for augmentation. degrade_scales = None # Degrade image for augmentation. @@ -88,6 +91,10 @@ def config_params(): f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. classification_classes_name = None # Dictionary of classification classes names. backbone_type = None # As backbone we have 2 types of backbones. 
A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer" + + dir_img_bin = None + number_of_backgrounds_per_image = 1 + dir_rgb_backgrounds = None @ex.automain @@ -95,15 +102,20 @@ def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, - blur_aug, padding_white, padding_black, scaling, degrading, - brightening, binarization, rgb_background, blur_k, scales, degrade_scales, + blur_aug, padding_white, padding_black, scaling, degrading,channels_shuffling, + brightening, binarization, adding_rgb_background, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, - pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds): + + if dir_rgb_backgrounds: + list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) + else: + list_all_possible_background_images = None if task == "segmentation" or task == "enhancement" or task == "binarization": if data_is_provided: @@ -163,18 +175,18 @@ def run(_config, n_classes, n_epochs, input_height, # writing patches into a sub-folder in order to be flowed from directory. 
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, + blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, - patches=patches) + patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) + flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) if weighted_loss: weights = np.zeros(n_classes) diff --git a/train/utils.py b/train/utils.py index 
cf7a65c..20fda29 100644 --- a/train/utils.py +++ b/train/utils.py @@ -51,6 +51,16 @@ def return_binary_image_with_given_rgb_background_red_textlines(img_bin, img_rgb return img_final +def return_image_with_red_elements(img, img_bin): + img_final = np.copy(img) + + img_final[:,:,0][img_bin[:,:,0]==0] = 0 + img_final[:,:,1][img_bin[:,:,0]==0] = 0 + img_final[:,:,2][img_bin[:,:,0]==0] = 255 + return img_final + + + def scale_image_for_no_patch(img, label, scale): h_n = int(img.shape[0]*scale) w_n = int(img.shape[1]*scale) @@ -631,10 +641,10 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, - padding_white, padding_black, flip_aug, binarization, scaling, degrading, - brightening, scales, degrade_scales, brightness, flip_index, + padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, scaling, degrading, + brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False): + rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None): indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): @@ -724,17 +734,29 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) indexer += 1 - if rgb_color_background: + if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): 
background_image_chosen_name = random.choice(list_all_possible_background_images) img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) - img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background) + img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + + if add_red_textlines: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/'+im), img_bin_corr) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_red_context, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + + indexer += 1 + From 61cdd2acb85e65ee023807ad885f1724e476596d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 22 Aug 2024 21:58:09 +0200 Subject: [PATCH 089/123] using prepared binarized images in the case of augmentation --- train/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/train/utils.py b/train/utils.py index 20fda29..84af85e 100644 --- a/train/utils.py +++ b/train/utils.py @@ -690,8 +690,15 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow pass if binarization: - cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', - resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) + + if dir_img_bin: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + + 
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + resize_image(img_bin_corr, input_height, input_width)) + else: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) From 5bbd0980b2a1ff3b5aa536353c21241539f6cf7b Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 28 Aug 2024 00:04:19 +0200 Subject: [PATCH 090/123] early dilation for textline artificial class --- train/gt_gen_utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 13010bf..dd4091f 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -88,12 +88,15 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) return contours_imgs -def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): +def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=None): co_text_eroded = [] for con in co_text: img_boundary_in = np.zeros( (y_len,x_len) ) img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) + if dilation_early: + img_boundary_in = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_early) + #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica if erosion_rate > 0: img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_rate) @@ -258,22 +261,25 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_label" in keys: img_boundary = np.zeros((y_len, x_len)) - erosion_rate = 1 + erosion_rate = 
0#1 dilation_rate = 3 - co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + dilation_early = 2 + co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early ) img = np.zeros((y_len, x_len, 3)) if output_type == '2d': img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) if "artificial_class_label" in keys: - img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + img_mask = np.copy(img_poly) + img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label elif output_type == '3d': img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) if "artificial_class_label" in keys: - img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] - img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] - img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + img_mask = np.copy(img_poly) + img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0] + img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1] + img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2] if printspace and config_params['use_case']!='printspace': From a57a31673d78741c5679aac66e06991e46fcec73 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 28 Aug 2024 02:09:27 +0200 Subject: [PATCH 091/123] adding foreground rgb to augmentation --- train/config_params.json | 10 ++++++---- train/train.py | 19 +++++++++++++------ train/utils.py | 40 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/train/config_params.json b/train/config_params.json index e5f652d..1db8026 100644 --- a/train/config_params.json +++ b/train/config_params.json @@ -13,13 +13,14 @@ "augmentation" : true, 
"flip_aug" : false, "blur_aug" : false, - "scaling" : true, + "scaling" : false, "adding_rgb_background": true, - "add_red_textlines": true, - "channels_shuffling": true, + "adding_rgb_foreground": true, + "add_red_textlines": false, + "channels_shuffling": false, "degrading": false, "brightening": false, - "binarization" : false, + "binarization" : true, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, @@ -51,6 +52,7 @@ "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new", "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", + "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" } diff --git a/train/train.py b/train/train.py index 5dfad07..848ff6a 100644 --- a/train/train.py +++ b/train/train.py @@ -54,6 +54,7 @@ def config_params(): brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. adding_rgb_background = False + adding_rgb_foreground = False add_red_textlines = False channels_shuffling = False dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". 
@@ -95,6 +96,7 @@ def config_params(): dir_img_bin = None number_of_backgrounds_per_image = 1 dir_rgb_backgrounds = None + dir_rgb_foregrounds = None @ex.automain @@ -103,20 +105,25 @@ def run(_config, n_classes, n_epochs, input_height, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, blur_aug, padding_white, padding_black, scaling, degrading,channels_shuffling, - brightening, binarization, adding_rgb_background, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, + brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, - pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds): + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): if dir_rgb_backgrounds: list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) else: list_all_possible_background_images = None + if dir_rgb_foregrounds: + list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds) + else: + list_all_possible_foreground_rgbs = None + if task == "segmentation" or task == "enhancement" or task == "binarization": if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') @@ -175,18 +182,18 @@ def run(_config, n_classes, n_epochs, input_height, # writing patches into a sub-folder in order to 
be flowed from directory. provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,add_red_textlines, channels_shuffling, + blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, - patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) + patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs) provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, 
patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds,dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs ) if weighted_loss: weights = np.zeros(n_classes) diff --git a/train/utils.py b/train/utils.py index 84af85e..d38e798 100644 --- a/train/utils.py +++ b/train/utils.py @@ -40,6 +40,25 @@ def return_binary_image_with_given_rgb_background(img_bin, img_rgb_background): return img_final +def return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin, img_rgb_background, rgb_foreground): + img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) + + img_final = np.copy(img_bin) + img_foreground = np.zeros(img_bin.shape) + + + img_foreground[:,:,0][img_bin[:,:,0] == 0] = rgb_foreground[0] + img_foreground[:,:,1][img_bin[:,:,0] == 0] = rgb_foreground[1] + img_foreground[:,:,2][img_bin[:,:,0] == 0] = rgb_foreground[2] + + + img_final[:,:,0][img_bin[:,:,0] != 0] = img_rgb_background[:,:,0][img_bin[:,:,0] != 0] + img_final[:,:,1][img_bin[:,:,1] != 0] = img_rgb_background[:,:,1][img_bin[:,:,1] != 0] + img_final[:,:,2][img_bin[:,:,2] != 0] = img_rgb_background[:,:,2][img_bin[:,:,2] != 0] + + img_final = img_final + img_foreground + return img_final + def return_binary_image_with_given_rgb_background_red_textlines(img_bin, img_rgb_background, img_color): img_rgb_background = resize_image(img_rgb_background ,img_bin.shape[0], img_bin.shape[1]) @@ -641,10 +660,10 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, 
img, label, height, width, i def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, - padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, scaling, degrading, + padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None): + rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None, dir_rgb_foregrounds=None, list_all_possible_foreground_rgbs=None): indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): @@ -754,6 +773,23 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer += 1 + if adding_rgb_foreground: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs) + + img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + foreground_rgb_chosen = np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name) + + img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) + + 
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + + indexer += 1 + if add_red_textlines: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/'+im), img_bin_corr) From e3da4944704d9d4af22a008addc1df8183a6ef44 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 28 Aug 2024 17:34:06 +0200 Subject: [PATCH 092/123] fixing artificial class bug --- train/gt_gen_utils.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index dd4091f..5784e14 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -8,6 +8,7 @@ from tqdm import tqdm import cv2 from shapely import geometry from pathlib import Path +import matplotlib.pyplot as plt KERNEL = np.ones((5, 5), np.uint8) @@ -83,9 +84,13 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): ret, thresh = cv2.threshold(imgray, 0, 255, 0) contours_imgs, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + #print(len(contours_imgs), hierarchy) contours_imgs = return_parent_contours(contours_imgs, hierarchy) - contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) + + #print(len(contours_imgs), "iki") + #contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) return contours_imgs def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=None): @@ -103,12 +108,15 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y pixel = 1 min_size = 0 + + img_boundary_in = 
img_boundary_in.astype("uint8") + con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) - try: - co_text_eroded.append(con_eroded[0]) - except: - co_text_eroded.append(con) + #try: + co_text_eroded.append(con_eroded[0]) + #except: + #co_text_eroded.append(con) img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) @@ -262,8 +270,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if "artificial_class_label" in keys: img_boundary = np.zeros((y_len, x_len)) erosion_rate = 0#1 - dilation_rate = 3 - dilation_early = 2 + dilation_rate = 2 + dilation_early = 1 co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early ) From 3f354e1c342a36d52883c61bacebcddf43a31c54 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 30 Aug 2024 15:30:18 +0200 Subject: [PATCH 093/123] new augmentations for patchwise training --- train/utils.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/train/utils.py b/train/utils.py index d38e798..3d42b64 100644 --- a/train/utils.py +++ b/train/utils.py @@ -823,6 +823,53 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow img_max_rotated, label_max_rotated, input_height, input_width, indexer=indexer) + + if channels_shuffling: + for shuffle_index in shuffle_indexes: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + return_shuffled_channels(cv2.imread(dir_img + '/' + im), shuffle_index), + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) + + if adding_rgb_background: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + img_rgb_background_chosen = 
cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) + + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + img_with_overlayed_background, + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) + + + if adding_rgb_foreground: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs) + + img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + foreground_rgb_chosen = np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name) + + img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen) + + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + img_with_overlayed_background, + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) + + + if add_red_textlines: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/'+im), img_bin_corr) + + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + img_red_context, + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) + if flip_aug: for f_i in flip_index: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, @@ -871,10 +918,19 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow input_height, input_width, indexer=indexer) if binarization: - indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, - otsu_copy(cv2.imread(dir_img + '/' + im)), - cv2.imread(dir_of_label_file), 
- input_height, input_width, indexer=indexer) + if dir_img_bin: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + img_bin_corr, + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) + + else: + indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, + otsu_copy(cv2.imread(dir_img + '/' + im)), + cv2.imread(dir_of_label_file), + input_height, input_width, indexer=indexer) if scaling_brightness: for sc_ind in scales: From a524f8b1a7e5e68219cdcb12e239bc6ae8a1391c Mon Sep 17 00:00:00 2001 From: johnlockejrr Date: Sat, 19 Oct 2024 13:21:29 -0700 Subject: [PATCH 094/123] Update inference.py to check if save_layout was passed as argument otherwise can give an cv2 error --- train/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train/inference.py b/train/inference.py index 8d0a572..89d32de 100644 --- a/train/inference.py +++ b/train/inference.py @@ -566,6 +566,7 @@ class sbb_predict: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: cv2.imwrite(self.save,img_seg_overlayed) + if self.save_layout: cv2.imwrite(self.save_layout, only_layout) if self.ground_truth: From f09eed1197d3f4d6cb4672fec48f73f50a1eee6b Mon Sep 17 00:00:00 2001 From: johnlockejrr Date: Sat, 19 Oct 2024 13:25:50 -0700 Subject: [PATCH 095/123] Changed deprecated `lr` to `learning_rate` and `model.fit_generator` to `model.fit` --- train/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train/train.py b/train/train.py index 848ff6a..4cc3cbb 100644 --- a/train/train.py +++ b/train/train.py @@ -277,16 +277,16 @@ def run(_config, n_classes, n_epochs, input_height, if (task == "segmentation" or task == "binarization"): if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + 
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) if is_loss_soft_dice: model.compile(loss=soft_dice_loss, - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) if weighted_loss: model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) elif task == "enhancement": model.compile(loss='mean_squared_error', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) # generating train and evaluation data @@ -299,7 +299,7 @@ def run(_config, n_classes, n_epochs, input_height, ##score_best=[] ##score_best.append(0) for i in tqdm(range(index_start, n_epochs + index_start)): - model.fit_generator( + model.fit( train_gen, steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, validation_data=val_gen, @@ -384,7 +384,7 @@ def run(_config, n_classes, n_epochs, input_height, #f1score_tot = [0] indexer_start = 0 - opt = SGD(lr=0.01, momentum=0.9) + opt = SGD(learning_rate=0.01, momentum=0.9) opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) model.compile(loss="binary_crossentropy", optimizer = opt_adam,metrics=['accuracy']) From fd14e656aa38b17ca25224268d2e66634506b107 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 25 Oct 2024 14:01:39 +0200 Subject: [PATCH 096/123] early_erosion is added --- train/gt_gen_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 5784e14..cabc7df 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -93,7 +93,7 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): #contours_imgs = filter_contours_area_of_image_tables(thresh, contours_imgs, hierarchy, max_area=1, min_area=min_area) return contours_imgs 
-def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=None): +def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=None, erosion_early=None): co_text_eroded = [] for con in co_text: img_boundary_in = np.zeros( (y_len,x_len) ) @@ -101,6 +101,9 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y if dilation_early: img_boundary_in = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_early) + + if erosion_early: + img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=erosion_early) #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica if erosion_rate > 0: @@ -137,6 +140,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ ls_org_imgs_stem = [item.split('.')[0] for item in ls_org_imgs] for index in tqdm(range(len(gt_list))): #try: + print(gt_list[index]) tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] @@ -271,8 +275,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_boundary = np.zeros((y_len, x_len)) erosion_rate = 0#1 dilation_rate = 2 - dilation_early = 1 - co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early ) + dilation_early = 0 + erosion_early = 2 + co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) img = np.zeros((y_len, x_len, 3)) @@ -280,7 +285,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) if "artificial_class_label" in keys: img_mask = np.copy(img_poly) - 
img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label + ##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label elif output_type == '3d': img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) if "artificial_class_label" in keys: From 7b4d14b19f536614545b209bf3834b6b84a67d1d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Oct 2024 17:06:22 +0100 Subject: [PATCH 097/123] addinh shifting augmentation --- train/train.py | 7 ++++--- train/utils.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/train/train.py b/train/train.py index 848ff6a..7e3e390 100644 --- a/train/train.py +++ b/train/train.py @@ -50,6 +50,7 @@ def config_params(): padding_white = False # If true, white padding will be applied to the image. padding_black = False # If true, black padding will be applied to the image. scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json. + shifting = False degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. 
@@ -104,7 +105,7 @@ def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, - blur_aug, padding_white, padding_black, scaling, degrading,channels_shuffling, + blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling, brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, @@ -183,7 +184,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,adding_rgb_foreground, add_red_textlines, channels_shuffling, - scaling, degrading, brightening, scales, degrade_scales, brightness, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs) @@ -191,7 +192,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, 
add_red_textlines, channels_shuffling, - scaling, degrading, brightening, scales, degrade_scales, brightness, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds,dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs ) diff --git a/train/utils.py b/train/utils.py index 3d42b64..d7ddb99 100644 --- a/train/utils.py +++ b/train/utils.py @@ -78,7 +78,50 @@ def return_image_with_red_elements(img, img_bin): img_final[:,:,2][img_bin[:,:,0]==0] = 255 return img_final +def shift_image_and_label(img, label, type_shift): + h_n = int(img.shape[0]*1.06) + w_n = int(img.shape[1]*1.06) + + channel0_avg = int( np.mean(img[:,:,0]) ) + channel1_avg = int( np.mean(img[:,:,1]) ) + channel2_avg = int( np.mean(img[:,:,2]) ) + h_diff = abs( img.shape[0] - h_n ) + w_diff = abs( img.shape[1] - w_n ) + + h_start = int(h_diff / 2.) + w_start = int(w_diff / 2.) 
+ + img_scaled_padded = np.zeros((h_n, w_n, 3)) + label_scaled_padded = np.zeros((h_n, w_n, 3)) + + img_scaled_padded[:,:,0] = channel0_avg + img_scaled_padded[:,:,1] = channel1_avg + img_scaled_padded[:,:,2] = channel2_avg + + img_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = img[:,:,:] + label_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = label[:,:,:] + + + if type_shift=="xpos": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xmin": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + elif type_shift=="ypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="ymin": + img_dis = img_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="xypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xymin": + img_dis = img_scaled_padded[:img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],:img.shape[1],:] + return img_dis, label_dis def scale_image_for_no_patch(img, label, scale): h_n = int(img.shape[0]*scale) @@ -660,7 +703,7 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, - padding_white, 
padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, degrading, + padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None, dir_rgb_foregrounds=None, list_all_possible_foreground_rgbs=None): @@ -759,6 +802,16 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_scaled, input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) indexer += 1 + if shifting: + shift_types = ['xpos', 'xmin', 'ypos', 'ymin', 'xypos', 'xymin'] + for st_ind in shift_types: + img_shifted, label_shifted = shift_image_and_label(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), st_ind) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_shifted, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_shifted, input_height, input_width)) + indexer += 1 + if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') From 238ea3bd8ef59da890646c9b1581145b8d937d85 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 14 Nov 2024 16:26:19 +0100 Subject: [PATCH 098/123] update resizing in inference --- train/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train/inference.py 
b/train/inference.py index 8d0a572..2b12ff7 100644 --- a/train/inference.py +++ b/train/inference.py @@ -442,10 +442,11 @@ class sbb_predict: self.img_org = np.copy(img) if img.shape[0] < self.img_height: - img = cv2.resize(img, (img.shape[1], self.img_width), interpolation=cv2.INTER_NEAREST) + img = self.resize_image(img, self.img_height, img.shape[1]) if img.shape[1] < self.img_width: - img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST) + img = self.resize_image(img, img.shape[0], self.img_width) + margin = int(0.1 * self.img_width) width_mid = self.img_width - 2 * margin height_mid = self.img_height - 2 * margin From e9b860b27513a255ec94892aec8b6a61e23d0b87 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 18 Nov 2024 16:34:53 +0100 Subject: [PATCH 099/123] artificial_class_label for table region --- train/gt_gen_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index cabc7df..95b8414 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -116,10 +116,10 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size ) - #try: - co_text_eroded.append(con_eroded[0]) - #except: - #co_text_eroded.append(con) + try: + co_text_eroded.append(con_eroded[0]) + except: + co_text_eroded.append(con) img_boundary_in_dilated = cv2.dilate(img_boundary_in[:,:], KERNEL, iterations=dilation_rate) @@ -636,6 +636,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ erosion_rate = 0#2 dilation_rate = 2#4 co_text["footnote-continued"], img_boundary = update_region_contours(co_text["footnote-continued"], img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "tableregion" in elements_with_artificial_class: + erosion_rate = 0#2 + dilation_rate = 3#4 + co_table, img_boundary = 
update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) From 90a1b186f78a9ad5934c4d46d93e1c2bf20d6789 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 14 Mar 2025 17:20:33 +0100 Subject: [PATCH 100/123] this enables to visualize reading order of textregions provided in page-xml files --- train/generate_gt_for_training.py | 67 +++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index cfcc151..9e0f45e 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -214,6 +214,73 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) indexer = indexer+1 + + +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-do", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + + +def visualize_reading_order(dir_xml, dir_out): + xml_files_ind = os.listdir(dir_xml) + + + indexer_start= 0#55166 + #min_area = 0.0001 + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = ind_xml.split('.')[0] + _, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) + + id_all_text = id_paragraph + id_header + co_text_all = co_text_paragraph + co_text_header + + + cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_all) + + texts_corr_order_index = [int(index_tot_regions[tot_region_ref.index(i)]) for i in id_all_text ] + #texts_corr_order_index_int = [int(x) for x in 
texts_corr_order_index] + + + #cx_ordered = np.array(cx_main)[np.array(texts_corr_order_index)] + #cx_ordered = cx_ordered.astype(np.int32) + + cx_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cx_main), key=lambda x: \ + x[0], reverse=False)] + #cx_ordered = cx_ordered.astype(np.int32) + + cy_ordered = [int(val) for (_, val) in sorted(zip(texts_corr_order_index, cy_main), key=lambda x: \ + x[0], reverse=False)] + #cy_ordered = cy_ordered.astype(np.int32) + + + color = (0, 0, 255) + thickness = 20 + + img = np.zeros( (y_len,x_len,3) ) + img = cv2.fillPoly(img, pts =co_text_all, color=(255,0,0)) + for i in range(len(cx_ordered)-1): + start_point = (int(cx_ordered[i]), int(cy_ordered[i])) + end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) + img = cv2.arrowedLine(img, start_point, end_point, + color, thickness, tipLength = 0.03) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img) From 363c343b373d99170d795ff20520ba9e586b4ab1 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 17 Mar 2025 20:09:48 +0100 Subject: [PATCH 101/123] visualising reaidng order- Overlaying on image is provided --- train/generate_gt_for_training.py | 36 ++++++++++++++------- train/gt_gen_utils.py | 53 +++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 9e0f45e..9869bfa 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -231,8 +231,12 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i type=click.Path(exists=True, file_okay=False), ) +@click.option( + "--dir_imgs", + "-dimg", + help="directory where the overlayed plots will be written", ) -def visualize_reading_order(dir_xml, dir_out): +def visualize_reading_order(dir_xml, dir_out, dir_imgs): xml_files_ind = os.listdir(dir_xml) @@ -271,16 +275,26 @@ def visualize_reading_order(dir_xml, dir_out): color = (0, 0, 
255) thickness = 20 - - img = np.zeros( (y_len,x_len,3) ) - img = cv2.fillPoly(img, pts =co_text_all, color=(255,0,0)) - for i in range(len(cx_ordered)-1): - start_point = (int(cx_ordered[i]), int(cy_ordered[i])) - end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) - img = cv2.arrowedLine(img, start_point, end_point, - color, thickness, tipLength = 0.03) - - cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img) + if dir_imgs: + layout = np.zeros( (y_len,x_len,3) ) + layout = cv2.fillPoly(layout, pts =co_text_all, color=(1,1,1)) + + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + overlayed = overlay_layout_on_image(layout, img, cx_ordered, cy_ordered, color, thickness) + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), overlayed) + + else: + img = np.zeros( (y_len,x_len,3) ) + img = cv2.fillPoly(img, pts =co_text_all, color=(255,0,0)) + for i in range(len(cx_ordered)-1): + start_point = (int(cx_ordered[i]), int(cy_ordered[i])) + end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) + img = cv2.arrowedLine(img, start_point, end_point, + color, thickness, tipLength = 0.03) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 95b8414..753abf2 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -1290,3 +1290,56 @@ def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be else: early_list_bigger_than_one = -20 return list_inp, early_list_bigger_than_one + +def overlay_layout_on_image(prediction, img, cx_ordered, cy_ordered, color, thickness): + + unique_classes = np.unique(prediction[:,:,0]) + rgb_colors = {'0' : [255, 255, 255], + '1' : [255, 0, 0], + '2' : [0, 0, 255], + '3' : [255, 0, 125], + '4' : [125, 125, 125], + '5' : [125, 125, 0], + '6' : [0, 125, 255], + '7' : [0, 125, 0], + '8' : [125, 125, 125], + '9' : [0, 125, 255], + '10' : [125, 
0, 125], + '11' : [0, 255, 0], + '12' : [255, 125, 0], + '13' : [0, 255, 255], + '14' : [255, 125, 125], + '15' : [255, 0, 255]} + + layout_only = np.zeros(prediction.shape) + + for unq_class in unique_classes: + rgb_class_unique = rgb_colors[str(int(unq_class))] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + + + + #img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1]) + + layout_only = layout_only.astype(np.int32) + + for i in range(len(cx_ordered)-1): + start_point = (int(cx_ordered[i]), int(cy_ordered[i])) + end_point = (int(cx_ordered[i+1]), int(cy_ordered[i+1])) + layout_only = cv2.arrowedLine(layout_only, start_point, end_point, + color, thickness, tipLength = 0.03) + + img = img.astype(np.int32) + + + + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) + + return added_image + +def find_format_of_given_filename_in_dir(dir_imgs, f_name): + ls_imgs = os.listdir(dir_imgs) + file_interested = [ind for ind in ls_imgs if ind.startswith(f_name+'.')] + return file_interested[0] From 825b2634f96788cc3351f089d24b8a1c2e202194 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 16 Apr 2025 23:36:41 +0200 Subject: [PATCH 102/123] rotation augmentation is provided for machine based reading order --- train/train.py | 7 +++++-- train/utils.py | 23 ++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/train/train.py b/train/train.py index 7e3e390..130c7f4 100644 --- a/train/train.py +++ b/train/train.py @@ -380,7 +380,10 @@ def run(_config, n_classes, n_epochs, input_height, dir_flow_train_labels = os.path.join(dir_train, 'labels') classes = os.listdir(dir_flow_train_labels) - num_rows =len(classes) + if augmentation: + num_rows = len(classes)*(len(thetha) + 1) + else: + num_rows = len(classes) #ls_test = os.listdir(dir_flow_train_labels) 
#f1score_tot = [0] @@ -390,7 +393,7 @@ def run(_config, n_classes, n_epochs, input_height, model.compile(loss="binary_crossentropy", optimizer = opt_adam,metrics=['accuracy']) for i in range(n_epochs): - history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=1) + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1) model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: diff --git a/train/utils.py b/train/utils.py index d7ddb99..50c21af 100644 --- a/train/utils.py +++ b/train/utils.py @@ -363,6 +363,11 @@ def rotation_not_90_func(img, label, thetha): return rotate_max_area(img, rotated, rotated_label, thetha) +def rotation_not_90_func_single_image(img, thetha): + rotated = imutils.rotate(img, thetha) + return rotate_max_area(img, rotated, thetha) + + def color_images(seg, n_classes): ann_u = range(n_classes) if len(np.shape(seg)) == 3: @@ -410,7 +415,7 @@ def IoU(Yi, y_predi): #print("Mean IoU: {:4.3f}".format(mIoU)) return mIoU -def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes): +def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False): all_labels_files = os.listdir(classes_file_dir) ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) @@ -433,6 +438,22 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) ret_y= np.zeros((batchsize, 
n_classes)).astype(np.int16) batchcount = 0 + + if augmentation: + for thetha_i in thetha: + img_rot = rotation_not_90_func_single_image(img, thetha_i) + + ret_x[batchcount, :,:,0] = img_rot[:,:,0]/3.0 + ret_x[batchcount, :,:,2] = img_rot[:,:,2]/3.0 + ret_x[batchcount, :,:,1] = img_rot[:,:,1]/5.0 + + ret_y[batchcount, :] = label_class + batchcount+=1 + if batchcount>=batchsize: + yield (ret_x, ret_y) + ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16) + ret_y= np.zeros((batchsize, n_classes)).astype(np.int16) + batchcount = 0 def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'): c = 0 From dd21a3b33a3adb1a8ba2c34e2144e01b2b094366 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 17 Apr 2025 00:05:59 +0200 Subject: [PATCH 103/123] updating:rotation augmentation is provided for machine based reading order --- train/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/train/utils.py b/train/utils.py index 50c21af..485056b 100644 --- a/train/utils.py +++ b/train/utils.py @@ -356,6 +356,18 @@ def rotate_max_area(image, rotated, rotated_label, angle): x2 = x1 + int(wr) return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2] +def rotate_max_area_single_image(image, rotated, angle): + """ image: cv2 image matrix object + angle: in degree + """ + wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], + math.radians(angle)) + h, w, _ = rotated.shape + y1 = h // 2 - int(hr / 2) + y2 = y1 + int(hr) + x1 = w // 2 - int(wr / 2) + x2 = x1 + int(wr) + return rotated[y1:y2, x1:x2] def rotation_not_90_func(img, label, thetha): rotated = imutils.rotate(img, thetha) @@ -365,7 +377,7 @@ def rotation_not_90_func(img, label, thetha): def rotation_not_90_func_single_image(img, thetha): rotated = imutils.rotate(img, thetha) - return rotate_max_area(img, rotated, thetha) + return rotate_max_area_single_image(img, rotated, thetha) def color_images(seg, n_classes): 
From 4635dd219d5cfade1c038a371dceb78452a7fbf9 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 17 Apr 2025 00:12:30 +0200 Subject: [PATCH 104/123] updating:rotation augmentation is provided for machine based reading order --- train/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train/utils.py b/train/utils.py index 485056b..8be6963 100644 --- a/train/utils.py +++ b/train/utils.py @@ -455,6 +455,8 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch for thetha_i in thetha: img_rot = rotation_not_90_func_single_image(img, thetha_i) + img_rot = resize_image(img_rot, height, width) + ret_x[batchcount, :,:,0] = img_rot[:,:,0]/3.0 ret_x[batchcount, :,:,2] = img_rot[:,:,2]/3.0 ret_x[batchcount, :,:,1] = img_rot[:,:,1]/5.0 From 3b123b039c432145359f7b6a3b0d45c8669df791 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sat, 3 May 2025 19:25:32 +0200 Subject: [PATCH 105/123] adding min_early parameter for generating training dataset for machine based reading order model --- train/generate_gt_for_training.py | 64 +++++++++++++++++++++++-------- train/gt_gen_utils.py | 13 ++++++- 2 files changed, 60 insertions(+), 17 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 9869bfa..77e9238 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -147,11 +147,20 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales): help="min area size of regions considered for reading order training.", ) -def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size): +@click.option( + "--min_area_early", + "-min_early", + help="If you have already generated a training dataset using a specific minimum area value and now wish to create a dataset with a smaller minimum area value, you can avoid regenerating the previous dataset by providing the earlier minimum area value. 
This will ensure that only the missing data is generated.", +) + +def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early): xml_files_ind = os.listdir(dir_xml) input_height = int(input_height) input_width = int(input_width) min_area = float(min_area_size) + if min_area_early: + min_area_early = float(min_area_early) + indexer_start= 0#55166 max_area = 1 @@ -181,7 +190,8 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] - co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + co_text_all, texts_corr_order_index_int, regions_ar_less_than_early_min = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area, min_area_early) + arg_array = np.array(range(len(texts_corr_order_index_int))) @@ -195,25 +205,49 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i labels_con[:,:,i] = img_label[:,:,0] + labels_con = resize_image(labels_con, input_height, input_width) + img_poly = resize_image(img_poly, input_height, input_width) + + for i in range(len(texts_corr_order_index_int)): for j in range(len(texts_corr_order_index_int)): if i!=j: - input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) - final_f_name = f_name+'_'+str(indexer+indexer_start) - order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] - if order_class_condition<0: - class_type = 1 + if regions_ar_less_than_early_min: + if regions_ar_less_than_early_min[i]==1: + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) + final_f_name = f_name+'_'+str(indexer+indexer_start) + order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] + if order_class_condition<0: + 
class_type = 1 + else: + class_type = 0 + + input_multi_visual_modal[:,:,0] = labels_con[:,:,i] + input_multi_visual_modal[:,:,1] = img_poly[:,:,0] + input_multi_visual_modal[:,:,2] = labels_con[:,:,j] + + np.save(os.path.join(dir_out_classes,final_f_name+'_missed.npy' ), class_type) + + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'_missed.png' ), input_multi_visual_modal) + indexer = indexer+1 + else: - class_type = 0 + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) + final_f_name = f_name+'_'+str(indexer+indexer_start) + order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] + if order_class_condition<0: + class_type = 1 + else: + class_type = 0 - input_multi_visual_modal[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) - input_multi_visual_modal[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) - input_multi_visual_modal[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) + input_multi_visual_modal[:,:,0] = labels_con[:,:,i] + input_multi_visual_modal[:,:,1] = img_poly[:,:,0] + input_multi_visual_modal[:,:,2] = labels_con[:,:,j] - np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) - - cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) - indexer = indexer+1 + np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) + + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) + indexer = indexer+1 @main.command() diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 753abf2..10183d6 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -51,9 +51,10 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m jv += 1 return found_polygons_early -def filter_contours_area_of_image(image, contours, order_index, max_area, min_area): +def filter_contours_area_of_image(image, contours, 
order_index, max_area, min_area, min_early): found_polygons_early = list() order_index_filtered = list() + regions_ar_less_than_early_min = list() #jv = 0 for jv, c in enumerate(contours): if len(np.shape(c)) == 3: @@ -68,8 +69,16 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint)) order_index_filtered.append(order_index[jv]) + if min_early: + if area < min_early * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : + regions_ar_less_than_early_min.append(1) + else: + regions_ar_less_than_early_min.append(0) + else: + regions_ar_less_than_early_min = None + #jv += 1 - return found_polygons_early, order_index_filtered + return found_polygons_early, order_index_filtered, regions_ar_less_than_early_min def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): From 5694d971c5c068413b0a35db1aceabd50963107d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 5 May 2025 15:39:05 +0200 Subject: [PATCH 106/123] saving model by steps is added to reading order and pixel wise segmentation use cases training --- train/train.py | 60 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/train/train.py b/train/train.py index 5dee567..df600a8 100644 --- a/train/train.py +++ b/train/train.py @@ -13,8 +13,29 @@ from tensorflow.keras.models import load_model from tqdm import tqdm import json from sklearn.metrics import f1_score +from tensorflow.keras.callbacks import Callback +class SaveWeightsAfterSteps(Callback): + def __init__(self, save_interval, save_path, _config): + super(SaveWeightsAfterSteps, self).__init__() + self.save_interval = save_interval + self.save_path = save_path + 
self.step_count = 0 + def on_train_batch_end(self, batch, logs=None): + self.step_count += 1 + + if self.step_count % self.save_interval ==0: + save_file = f"{self.save_path}/model_step_{self.step_count}" + #os.system('mkdir '+save_file) + + self.model.save(save_file) + + with open(os.path.join(os.path.join(save_path, "model_step_{self.step_count}"),"config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON + print(f"saved model as steps {self.step_count} to {save_file}") + + def configuration(): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True @@ -93,7 +114,7 @@ def config_params(): f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. classification_classes_name = None # Dictionary of classification classes names. backbone_type = None # As backbone we have 2 types of backbones. 
A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer" - + save_interval = None dir_img_bin = None number_of_backgrounds_per_image = 1 dir_rgb_backgrounds = None @@ -112,7 +133,7 @@ def run(_config, n_classes, n_epochs, input_height, thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, transformer_patchsize_x, transformer_patchsize_y, - transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, + transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): if dir_rgb_backgrounds: @@ -299,13 +320,27 @@ def run(_config, n_classes, n_epochs, input_height, ##img_validation_patches = os.listdir(dir_flow_eval_imgs) ##score_best=[] ##score_best.append(0) + + if save_interval: + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) + + for i in tqdm(range(index_start, n_epochs + index_start)): - model.fit( - train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, - validation_data=val_gen, - validation_steps=1, - epochs=1) + if save_interval: + model.fit( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + validation_data=val_gen, + validation_steps=1, + epochs=1, callbacks=[save_weights_callback]) + else: + model.fit( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + validation_data=val_gen, + validation_steps=1, + epochs=1) + model.save(os.path.join(dir_output,'model_'+str(i))) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: @@ -392,8 +427,15 @@ def run(_config, n_classes, n_epochs, input_height, opt_adam = 
tf.keras.optimizers.Adam(learning_rate=0.0001) model.compile(loss="binary_crossentropy", optimizer = opt_adam,metrics=['accuracy']) + + if save_interval: + save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) + for i in range(n_epochs): - history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1) + if save_interval: + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1, callbacks=[save_weights_callback]) + else: + history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1) model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) )) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: From 92954b1b7b7363f8cdae91500cf0e729c2eebc62 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 5 May 2025 16:13:38 +0200 Subject: [PATCH 107/123] resolving issued with saving model by steps --- train/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train/train.py b/train/train.py index df600a8..f6a4f47 100644 --- a/train/train.py +++ b/train/train.py @@ -21,6 +21,7 @@ class SaveWeightsAfterSteps(Callback): self.save_interval = save_interval self.save_path = save_path self.step_count = 0 + self._config = _config def on_train_batch_end(self, batch, logs=None): self.step_count += 1 @@ -31,8 +32,8 @@ class SaveWeightsAfterSteps(Callback): self.model.save(save_file) - with open(os.path.join(os.path.join(save_path, "model_step_{self.step_count}"),"config.json"), "w") as fp: - 
json.dump(_config, fp) # encode dict into JSON + with open(os.path.join(os.path.join(self.save_path, f"model_step_{self.step_count}"),"config.json"), "w") as fp: + json.dump(self._config, fp) # encode dict into JSON print(f"saved model as steps {self.step_count} to {save_file}") From 6fa766d6a566fa4660c0c7424ddebb85f1a0d0c7 Mon Sep 17 00:00:00 2001 From: johnlockejrr <16368414+johnlockejrr@users.noreply.github.com> Date: Sun, 11 May 2025 05:31:34 -0700 Subject: [PATCH 108/123] Update utils.py --- train/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train/utils.py b/train/utils.py index 3d42b64..cba20c2 100644 --- a/train/utils.py +++ b/train/utils.py @@ -667,7 +667,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): - img_name = im.split('.')[0] + img_name = os.path.splitext(im)[0] if task == "segmentation" or task == "binarization": dir_of_label_file = os.path.join(dir_seg, img_name + '.png') elif task=="enhancement": From 3a9fc0efde07a4890995adbfefc8d135e9278747 Mon Sep 17 00:00:00 2001 From: johnlockejrr <16368414+johnlockejrr@users.noreply.github.com> Date: Sun, 11 May 2025 06:09:17 -0700 Subject: [PATCH 109/123] Update utils.py Changed unsafe basename extraction: `file_name = i.split('.')[0]` to `file_name = os.path.splitext(i)[0]` and `filename = n[i].split('.')[0]` to `filename = os.path.splitext(n[i])[0]` because `"Vat.sam.2_206.jpg` -> `Vat` instead of `"Vat.sam.2_206` --- train/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train/utils.py b/train/utils.py index cba20c2..bbe21d1 100644 --- a/train/utils.py +++ b/train/utils.py @@ -374,7 +374,7 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch batchcount = 0 while True: for i in all_labels_files: - file_name = i.split('.')[0] + file_name = os.path.splitext(i)[0] img = 
cv2.imread(os.path.join(modal_dir,file_name+'.png')) label_class = int( np.load(os.path.join(classes_file_dir,i)) ) @@ -401,7 +401,7 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c for i in range(c, c + batch_size): # initially from 0 to 16, c = 0. try: - filename = n[i].split('.')[0] + filename = os.path.splitext(n[i])[0] train_img = cv2.imread(img_folder + '/' + n[i]) / 255. train_img = cv2.resize(train_img, (input_width, input_height), From 4ddc84dee87ed7e1b600592ba8e96cad93e653e3 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 12 May 2025 18:31:40 +0200 Subject: [PATCH 110/123] visulizing textline detection from eynollah page-xml output --- train/generate_gt_for_training.py | 48 +++++++++++++++++ train/gt_gen_utils.py | 88 +++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 77e9238..9ce743a 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -2,6 +2,7 @@ import click import json from gt_gen_utils import * from tqdm import tqdm +from pathlib import Path @click.group() def main(): @@ -331,6 +332,53 @@ def visualize_reading_order(dir_xml, dir_out, dir_imgs): cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img) +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-do", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_imgs", + "-dimg", + help="directory of images where textline segmentation will be overlayed", ) + +def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): + xml_files_ind = os.listdir(dir_xml) + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name 
= Path(ind_xml).stem + + img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) + + img_total = np.zeros((y_len, x_len, 3)) + for cont in co_tetxlines: + img_in = np.zeros((y_len, x_len, 3)) + img_in = cv2.fillPoly(img_in, pts =[cont], color=(1,1,1)) + + img_total = img_total + img_in + + img_total[:,:, 0][img_total[:,:, 0]>2] = 2 + + img_out, _ = visualize_model_output(img_total, img, task="textline") + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img_out) + if __name__ == "__main__": main() diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 10183d6..0a65f05 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -16,6 +16,52 @@ KERNEL = np.ones((5, 5), np.uint8) with warnings.catch_warnings(): warnings.simplefilter("ignore") +def visualize_model_output(prediction, img, task): + if task == "binarization": + prediction = prediction * -1 + prediction = prediction + 1 + added_image = prediction * 255 + layout_only = None + else: + unique_classes = np.unique(prediction[:,:,0]) + rgb_colors = {'0' : [255, 255, 255], + '1' : [255, 0, 0], + '2' : [255, 125, 0], + '3' : [255, 0, 125], + '4' : [125, 125, 125], + '5' : [125, 125, 0], + '6' : [0, 125, 255], + '7' : [0, 125, 0], + '8' : [125, 125, 125], + '9' : [0, 125, 255], + '10' : [125, 0, 125], + '11' : [0, 255, 0], + '12' : [0, 0, 255], + '13' : [0, 255, 255], + '14' : [255, 125, 125], + '15' : [255, 0, 255]} + + layout_only = np.zeros(prediction.shape) + + for unq_class in unique_classes: + rgb_class_unique = rgb_colors[str(int(unq_class))] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + + + + img = resize_image(img, layout_only.shape[0], 
layout_only.shape[1]) + + layout_only = layout_only.astype(np.int32) + img = img.astype(np.int32) + + + + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) + + return added_image, layout_only + def get_content_of_dir(dir_in): """ Listing all ground truth page xml files. All files are needed to have xml format. @@ -138,6 +184,48 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y img_boundary[:,:][boundary[:,:]==1] =1 return co_text_eroded, img_boundary + +def get_textline_contours_for_visualization(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + tag_endings = ['}TextLine','}textline'] + co_use_case = [] + + for tag in region_tags: + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + return co_use_case, y_len, x_len + + def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ Reading the page xml files and write the ground truth images into given output directory. 
From 4a7728bb346aeccf76a34a6e0ec900e4df40a765 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 12 May 2025 22:39:47 +0200 Subject: [PATCH 111/123] visuliazation layout from eynollah page-xml output --- train/generate_gt_for_training.py | 53 ++++- train/gt_gen_utils.py | 312 ++++++++++++++++++++++++++++++ 2 files changed, 355 insertions(+), 10 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 9ce743a..7e7c6a0 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -366,18 +366,51 @@ def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): co_tetxlines, y_len, x_len = get_textline_contours_for_visualization(xml_file) - img_total = np.zeros((y_len, x_len, 3)) - for cont in co_tetxlines: - img_in = np.zeros((y_len, x_len, 3)) - img_in = cv2.fillPoly(img_in, pts =[cont], color=(1,1,1)) - - img_total = img_total + img_in - - img_total[:,:, 0][img_total[:,:, 0]>2] = 2 + added_image = visualize_image_from_contours(co_tetxlines, img) - img_out, _ = visualize_model_output(img_total, img, task="textline") + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + - cv2.imwrite(os.path.join(dir_out, f_name+'.png'), img_out) + +@main.command() +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-do", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_imgs", + "-dimg", + help="directory of images where textline segmentation will be overlayed", ) + +def visualize_layout_segmentation(dir_xml, dir_out, dir_imgs): + xml_files_ind = os.listdir(dir_xml) + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + + img_file_name_with_format = 
find_format_of_given_filename_in_dir(dir_imgs, f_name) + img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) + + co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) + + + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], img) + + cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + if __name__ == "__main__": diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 0a65f05..9b67563 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -15,6 +15,63 @@ KERNEL = np.ones((5, 5), np.uint8) with warnings.catch_warnings(): warnings.simplefilter("ignore") + + +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, img): + alpha = 0.5 + + blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 + + col_header = (173, 216, 230) + col_drop = (0, 191, 255) + boundary_color = (143, 216, 200)#(0, 0, 255) # Dark gray for the boundary + col_par = (0, 0, 139) # Lighter gray for the filled area + col_image = (0, 100, 0) + col_sep = (255, 0, 0) + col_marginal = (106, 90, 205) + + if len(co_image)>0: + cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour + + if len(co_sep)>0: + cv2.drawContours(blank_image, co_sep, -1, col_sep, thickness=cv2.FILLED) # Fill the contour + + + if len(co_header)>0: + cv2.drawContours(blank_image, co_header, -1, col_header, thickness=cv2.FILLED) # Fill the contour + + if len(co_par)>0: + cv2.drawContours(blank_image, co_par, -1, col_par, thickness=cv2.FILLED) # Fill the contour + + cv2.drawContours(blank_image, co_par, -1, boundary_color, thickness=1) # Draw the boundary + + if len(co_drop)>0: + cv2.drawContours(blank_image, co_drop, -1, col_drop, thickness=cv2.FILLED) # Fill the contour + + if len(co_marginal)>0: + cv2.drawContours(blank_image, 
co_marginal, -1, col_marginal, thickness=cv2.FILLED) # Fill the contour + + img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) + + added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0) + return added_image + + +def visualize_image_from_contours(contours, img): + alpha = 0.5 + + blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 + + boundary_color = (0, 0, 255) # Dark gray for the boundary + fill_color = (173, 216, 230) # Lighter gray for the filled area + + cv2.drawContours(blank_image, contours, -1, fill_color, thickness=cv2.FILLED) # Fill the contour + cv2.drawContours(blank_image, contours, -1, boundary_color, thickness=1) # Draw the boundary + + img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) + + added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0) + return added_image def visualize_model_output(prediction, img, task): if task == "binarization": @@ -224,7 +281,262 @@ def get_textline_contours_for_visualization(xml_file): break co_use_case.append(np.array(c_t_in)) return co_use_case, y_len, x_len + + +def get_layout_contours_for_visualization(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) + co_text = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + all_defined_textregion_types = list(co_text.keys()) + co_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + all_defined_graphic_types = list(co_graphic.keys()) + co_sep=[] + co_img=[] + co_table=[] + co_noise=[] + types_text = [] + + for tag in region_tags: + if 
tag.endswith('}TextRegion') or tag.endswith('}Textregion'): + for nn in root1.iter(tag): + c_t_in = {'drop-capital':[], "footnote":[], "footnote-continued":[], "heading":[], "signature-mark":[], "header":[], "catch-word":[], "page-number":[], "marginalia":[], "paragraph":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_paragraph" in types_text: + types_text_without_paragraph = [element for element in types_text if element!='rest_as_paragraph' and element!='paragraph'] + if len(types_text_without_paragraph) == 0: + if "type" in nn.attrib: + c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_text_without_paragraph) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_text_without_paragraph: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + 
c_t_in['paragraph'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_textregion_types: + c_t_in[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_text in list(c_t_in.keys()): + if len(c_t_in[element_text])>0: + co_text[element_text].append(np.array(c_t_in[element_text])) + + + if tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in_graphic = {"handwritten-annotation":[], "decoration":[], "stamp":[], "signature":[]} + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in_graphic['decoration'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + + break + else: + pass + + + if vv.tag==link+'Point': + if "rest_as_decoration" in types_graphic: + types_graphic_without_decoration = [element for 
element in types_graphic if element!='rest_as_decoration' and element!='decoration'] + if len(types_graphic_without_decoration) == 0: + if "type" in nn.attrib: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + elif len(types_graphic_without_decoration) >= 1: + if "type" in nn.attrib: + if nn.attrib['type'] in types_graphic_without_decoration: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + else: + c_t_in_graphic['decoration'].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + else: + if "type" in nn.attrib: + if nn.attrib['type'] in all_defined_graphic_types: + c_t_in_graphic[nn.attrib['type']].append( [ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ] ) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + + for element_graphic in list(c_t_in_graphic.keys()): + if len(c_t_in_graphic[element_graphic])>0: + co_graphic[element_graphic].append(np.array(c_t_in_graphic[element_graphic])) + + + if tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_img.append(np.array(c_t_in)) + + + if tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , 
int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + + elif vv.tag!=link+'Point' and sumi>=1: + break + co_sep.append(np.array(c_t_in)) + + + if tag.endswith('}TableRegion') or tag.endswith('}tableregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_table.append(np.array(c_t_in)) + + + if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_noise.append(np.array(c_t_in)) + return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ From 25abc0fabc8a70b6a9c21c35006c08fec577d792 Mon Sep 17 00:00:00 2001 From: johnlockejrr <16368414+johnlockejrr@users.noreply.github.com> Date: Wed, 14 May 2025 03:34:51 -0700 Subject: [PATCH 112/123] Update gt_gen_utils.py Keep 
safely the full basename without extension --- train/gt_gen_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 5784e14..8837462 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -22,7 +22,7 @@ def get_content_of_dir(dir_in): """ gt_all=os.listdir(dir_in) - gt_list=[file for file in gt_all if file.split('.')[ len(file.split('.'))-1 ]=='xml' ] + gt_list = [file for file in gt_all if os.path.splitext(file)[1] == '.xml'] return gt_list def return_parent_contours(contours, hierarchy): @@ -134,7 +134,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ if dir_images: ls_org_imgs = os.listdir(dir_images) - ls_org_imgs_stem = [item.split('.')[0] for item in ls_org_imgs] + ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs] for index in tqdm(range(len(gt_list))): #try: tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) @@ -298,10 +298,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly = resize_image(img_poly, y_new, x_new) try: - xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + xml_file_stem = os.path.splitext(gt_list[index])[0] cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - xml_file_stem = gt_list[index].split('.')[0] + xml_file_stem = os.path.splitext(gt_list[index])[0] cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) if dir_images: @@ -757,10 +757,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly = resize_image(img_poly, y_new, x_new) try: - xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + xml_file_stem = os.path.splitext(gt_list[index])[0] cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - xml_file_stem = gt_list[index].split('.')[0] + xml_file_stem = 
os.path.splitext(gt_list[index])[0] cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) From f9390c71e7ec3c577e80ad4a8894417481407f02 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sat, 17 May 2025 02:18:27 +0200 Subject: [PATCH 113/123] updating inference for mb reading order --- train/gt_gen_utils.py | 2 +- train/inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 9b67563..a734020 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -154,7 +154,7 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m jv += 1 return found_polygons_early -def filter_contours_area_of_image(image, contours, order_index, max_area, min_area, min_early): +def filter_contours_area_of_image(image, contours, order_index, max_area, min_area, min_early=None): found_polygons_early = list() order_index_filtered = list() regions_ar_less_than_early_min = list() diff --git a/train/inference.py b/train/inference.py index db3b31f..aecd0e6 100644 --- a/train/inference.py +++ b/train/inference.py @@ -267,7 +267,7 @@ class sbb_predict: #print(np.shape(co_text_all[0]), len( np.shape(co_text_all[0]) ),'co_text_all') #co_text_all = filter_contours_area_of_image_tables(img_poly, co_text_all, _, max_area, min_area) #print(co_text_all,'co_text_all') - co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area) + co_text_all, texts_corr_order_index_int, _ = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, self.min_area) #print(texts_corr_order_index_int) From 25e3a2a99f4e585ee73d39e981897062ccd13a1e Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 23 May 2025 18:30:51 +0200 Subject: [PATCH 114/123] visualizing ro for single xml file --- train/generate_gt_for_training.py | 53 +++++++++++++++++++++++++------ 1 file changed, 44 
insertions(+), 9 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 7e7c6a0..9b7f02b 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -252,6 +252,12 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i @main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) @click.option( "--dir_xml", "-dx", @@ -271,10 +277,14 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i "-dimg", help="directory where the overlayed plots will be written", ) -def visualize_reading_order(dir_xml, dir_out, dir_imgs): - xml_files_ind = os.listdir(dir_xml) - +def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + else: + xml_files_ind = [xml_file] + indexer_start= 0#55166 #min_area = 0.0001 @@ -282,8 +292,17 @@ def visualize_reading_order(dir_xml, dir_out, dir_imgs): indexer = 0 #print(ind_xml) #print('########################') - xml_file = os.path.join(dir_xml,ind_xml ) - f_name = ind_xml.split('.')[0] + #xml_file = os.path.join(dir_xml,ind_xml ) + + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') + + #f_name = ind_xml.split('.')[0] _, _, _, file_name, id_paragraph, id_header,co_text_paragraph,co_text_header,tot_region_ref,x_len, y_len,index_tot_regions,img_poly = read_xml(xml_file) id_all_text = id_paragraph + id_header @@ -373,6 +392,12 @@ def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): @main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) @click.option( 
"--dir_xml", "-dx", @@ -392,14 +417,24 @@ def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): "-dimg", help="directory of images where textline segmentation will be overlayed", ) -def visualize_layout_segmentation(dir_xml, dir_out, dir_imgs): - xml_files_ind = os.listdir(dir_xml) +def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file and dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + else: + xml_files_ind = [xml_file] + for ind_xml in tqdm(xml_files_ind): indexer = 0 #print(ind_xml) #print('########################') - xml_file = os.path.join(dir_xml,ind_xml ) - f_name = Path(ind_xml).stem + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) From eb91000490282e2ea0d6058032f69f29da7783b6 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 2 Jun 2025 18:23:34 +0200 Subject: [PATCH 115/123] layout visualization updated --- train/generate_gt_for_training.py | 4 ++-- train/gt_gen_utils.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 9b7f02b..8ca5cd3 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -418,7 +418,7 @@ def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): help="directory of images where textline segmentation will be overlayed", ) def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): - assert xml_file and dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + assert xml_file or dir_xml, "A single xml file -xml or a dir 
of xml files -dx is required not both of them" if dir_xml: xml_files_ind = os.listdir(dir_xml) else: @@ -442,7 +442,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], img) + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], img) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index a734020..0ac15a2 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -306,6 +306,7 @@ def get_layout_contours_for_visualization(xml_file): co_noise=[] types_text = [] + types_graphic = [] for tag in region_tags: if tag.endswith('}TextRegion') or tag.endswith('}Textregion'): @@ -325,6 +326,9 @@ def get_layout_contours_for_visualization(xml_file): if len(types_text_without_paragraph) == 0: if "type" in nn.attrib: c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + elif len(types_text_without_paragraph) >= 1: if "type" in nn.attrib: if nn.attrib['type'] in types_text_without_paragraph: @@ -332,10 +336,15 @@ def get_layout_contours_for_visualization(xml_file): else: c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: if "type" in nn.attrib: if nn.attrib['type'] in all_defined_textregion_types: c_t_in[nn.attrib['type']].append( np.array( [ [ 
int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + else: + c_t_in['paragraph'].append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) break else: From f5a1d1a255a080469ba4624d7912b6e5e4cc7647 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 25 Jun 2025 18:24:16 +0200 Subject: [PATCH 116/123] docker file to train model with desired cuda and cudnn --- train/Dockerfile | 29 ++++++++++++++++++ train/config_params_docker.json | 54 +++++++++++++++++++++++++++++++++ train/train.py | 2 +- 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 train/Dockerfile create mode 100644 train/config_params_docker.json diff --git a/train/Dockerfile b/train/Dockerfile new file mode 100644 index 0000000..2456ea4 --- /dev/null +++ b/train/Dockerfile @@ -0,0 +1,29 @@ +# Use NVIDIA base image +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 + +# Set the working directory +WORKDIR /app + + +# Set environment variable for GitPython +ENV GIT_PYTHON_REFRESH=quiet + +# Install Python and pip +RUN apt-get update && apt-get install -y --fix-broken && \ + apt-get install -y \ + python3 \ + python3-pip \ + python3-distutils \ + python3-setuptools \ + python3-wheel && \ + rm -rf /var/lib/apt/lists/* + +# Copy and install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the application +COPY . . 
+ +# Specify the entry point +CMD ["python3", "train.py", "with", "config_params_docker.json"] diff --git a/train/config_params_docker.json b/train/config_params_docker.json new file mode 100644 index 0000000..45f87d3 --- /dev/null +++ b/train/config_params_docker.json @@ -0,0 +1,54 @@ +{ + "backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 3, + "n_epochs" : 1, + "input_height" : 672, + "input_width" : 448, + "weight_decay" : 1e-6, + "n_batch" : 4, + "learning_rate": 1e-4, + "patches" : false, + "pretraining" : true, + "augmentation" : false, + "flip_aug" : false, + "blur_aug" : true, + "scaling" : true, + "adding_rgb_background": false, + "adding_rgb_foreground": false, + "add_red_textlines": false, + "channels_shuffling": true, + "degrading": true, + "brightening": true, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": true, + "transformer_num_patches_xy": [14, 21], + "transformer_patchsize_x": 1, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": true, + "blur_k" : ["blur","gauss","median"], + "scales" : [0.6, 0.7, 0.8, 0.9], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [5, -5], + "number_of_backgrounds_per_image": 2, + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": true, + "data_is_provided": false, + "dir_train": "/entry_point_dir/train", + "dir_eval": "/entry_point_dir/eval", + "dir_output": "/entry_point_dir/output" +} diff --git a/train/train.py b/train/train.py index f6a4f47..e8e92af 100644 --- a/train/train.py +++ b/train/train.py @@ -53,7 +53,7 @@ def 
get_dirs_or_files(input_data): return image_input, labels_input -ex = Experiment() +ex = Experiment(save_git_info=False) @ex.config From 1b222594d694884108428d47a74aa67111d40218 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 25 Jun 2025 18:33:55 +0200 Subject: [PATCH 117/123] Update README.md: how to train model using docker image --- train/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/train/README.md b/train/README.md index b9e70a8..7c69a10 100644 --- a/train/README.md +++ b/train/README.md @@ -24,7 +24,19 @@ each class will be defined with a RGB value and beside images, a text file of cl ### Train To train a model, run: ``python train.py with config_params.json`` - + +### Train using Docker + +#### Build the Docker image + + ```bash + docker build -t model-training . + ``` +#### Run Docker image + ```bash + docker run --gpus all -v /host/path/to/entry_point_dir:/entry_point_dir model-training + ``` + ### Ground truth format Lables for each pixel are identified by a number. 
So if you have a binary case, ``n_classes`` should be set to ``2`` and labels should From 6462ea5b33cd6e4c1eaac1b2bf1fe072147e76f9 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 6 Aug 2025 22:33:42 +0200 Subject: [PATCH 118/123] adding visualization of ocr text of xml file --- train/generate_gt_for_training.py | 81 +++++++++++++++++++++++++++++++ train/gt_gen_utils.py | 71 +++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 8ca5cd3..1971f68 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -3,6 +3,7 @@ import json from gt_gen_utils import * from tqdm import tqdm from pathlib import Path +from PIL import Image, ImageDraw, ImageFont @click.group() def main(): @@ -447,6 +448,86 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) + + +@main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) +@click.option( + "--dir_xml", + "-dx", + help="directory of GT page-xml files", + type=click.Path(exists=True, file_okay=False), +) + +@click.option( + "--dir_out", + "-do", + help="directory where plots will be written", + type=click.Path(exists=True, file_okay=False), +) + + +def visualize_ocr_text(xml_file, dir_xml, dir_out): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + else: + xml_files_ind = [xml_file] + + font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! 
+ font = ImageFont.truetype(font_path, 40) + + for ind_xml in tqdm(xml_files_ind): + indexer = 0 + #print(ind_xml) + #print('########################') + if dir_xml: + xml_file = os.path.join(dir_xml,ind_xml ) + f_name = Path(ind_xml).stem + else: + xml_file = os.path.join(ind_xml ) + f_name = Path(ind_xml).stem + print(f_name, 'f_name') + + co_tetxlines, y_len, x_len, ocr_texts = get_textline_contours_and_ocr_text(xml_file) + + total_bb_coordinates = [] + + image_text = Image.new("RGB", (x_len, y_len), "white") + draw = ImageDraw.Draw(image_text) + + + + for index, cnt in enumerate(co_tetxlines): + x,y,w,h = cv2.boundingRect(cnt) + #total_bb_coordinates.append([x,y,w,h]) + + #fit_text_single_line + + #x_bb = bb_ind[0] + #y_bb = bb_ind[1] + #w_bb = bb_ind[2] + #h_bb = bb_ind[3] + + font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x + (w - text_width) // 2 # Center horizontally + text_y = y + (h - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) + image_text.save(os.path.join(dir_out, f_name+'.png')) if __name__ == "__main__": main() diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 5076dd6..907e04d 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -9,6 +9,7 @@ import cv2 from shapely import geometry from pathlib import Path import matplotlib.pyplot as plt +from PIL import Image, ImageDraw, ImageFont KERNEL = np.ones((5, 5), np.uint8) @@ -283,6 +284,76 @@ def get_textline_contours_for_visualization(xml_file): return co_use_case, y_len, x_len +def get_textline_contours_and_ocr_text(xml_file): + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + 
root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + tag_endings = ['}TextLine','}textline'] + co_use_case = [] + ocr_textlines = [] + + for tag in region_tags: + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + ocr_text_in = [''] + sumi = 0 + for vv in nn.iter(): + if vv.tag == link + 'Coords': + for childtest2 in nn: + if childtest2.tag.endswith("TextEquiv"): + for child_uc in childtest2: + if child_uc.tag.endswith("Unicode"): + text = child_uc.text + ocr_text_in[0]= text + + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + + + co_use_case.append(np.array(c_t_in)) + ocr_textlines.append(ocr_text_in[0]) + return co_use_case, y_len, x_len, ocr_textlines + +def fit_text_single_line(draw, text, font_path, max_width, max_height): + initial_font_size = 50 + font_size = initial_font_size + while font_size > 10: # Minimum font size + font = ImageFont.truetype(font_path, font_size) + text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + if text_width <= max_width and text_height <= max_height: + return font # Return the best-fitting font + + font_size -= 2 # Reduce font size and retry + + return ImageFont.truetype(font_path, 10) # Smallest font fallback + def get_layout_contours_for_visualization(xml_file): tree1 = ET.parse(xml_file, parser = 
ET.XMLParser(encoding = 'iso-8859-5')) root1=tree1.getroot() From 263da755ef5d1a03f6398d090b02a094025a52aa Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Thu, 7 Aug 2025 10:32:49 +0200 Subject: [PATCH 119/123] loading xmls with UTF-8 encoding --- train/generate_gt_for_training.py | 26 +++++++++++++------------- train/gt_gen_utils.py | 10 +++++----- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 1971f68..d4b58dc 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -495,7 +495,7 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): print(f_name, 'f_name') co_tetxlines, y_len, x_len, ocr_texts = get_textline_contours_and_ocr_text(xml_file) - + total_bb_coordinates = [] image_text = Image.new("RGB", (x_len, y_len), "white") @@ -513,20 +513,20 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): #y_bb = bb_ind[1] #w_bb = bb_ind[2] #h_bb = bb_ind[3] - - font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] + if ocr_texts[index]: + font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] - text_x = x + (w - text_width) // 2 # Center horizontally - text_y = y + (h - text_height) // 2 # Center vertically + text_x = x + (w - text_width) // 2 # Center horizontally + text_y = y + (h - text_height) // 2 # Center vertically - # Draw the text - draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) + # Draw the 
text + draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) image_text.save(os.path.join(dir_out, f_name+'.png')) if __name__ == "__main__": diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 907e04d..753b0f5 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -244,7 +244,7 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y return co_text_eroded, img_boundary def get_textline_contours_for_visualization(xml_file): - tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' @@ -285,7 +285,7 @@ def get_textline_contours_for_visualization(xml_file): def get_textline_contours_and_ocr_text(xml_file): - tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' @@ -355,7 +355,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height): return ImageFont.truetype(font_path, 10) # Smallest font fallback def get_layout_contours_for_visualization(xml_file): - tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' @@ -630,7 +630,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ for index in tqdm(range(len(gt_list))): #try: print(gt_list[index]) - tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] 
link=alltags[0].split('}')[0]+'}' @@ -1311,7 +1311,7 @@ def find_new_features_of_contours(contours_main): return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin def read_xml(xml_file): file_name = Path(xml_file).stem - tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding = 'iso-8859-5')) + tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] link=alltags[0].split('}')[0]+'}' From cf4983da54a1d8e0e5e382569a5502110b438189 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 8 Aug 2025 16:12:55 +0200 Subject: [PATCH 120/123] visualize vertical ocr text vertically --- train/generate_gt_for_training.py | 36 +++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index d4b58dc..91ee2c8 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -514,19 +514,37 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): #w_bb = bb_ind[2] #h_bb = bb_ind[3] if ocr_texts[index]: + + + is_vertical = h > 2*w # Check orientation font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] + if is_vertical: + + vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) - text_x = x + (w - text_width) // 2 # Center horizontally - text_y = y + (h - text_height) // 2 # Center vertically + text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped + text_draw = ImageDraw.Draw(text_img) + text_draw.text((0, 0), ocr_texts[index], font=vertical_font, fill="black") - # Draw the text - draw.text((text_x, text_y), 
ocr_texts[index], fill="black", font=font) + # Rotate text image by 90 degrees + rotated_text = text_img.rotate(90, expand=1) + + # Calculate paste position (centered in bbox) + paste_x = x + (w - rotated_text.width) // 2 + paste_y = y + (h - rotated_text.height) // 2 + + image_text.paste(rotated_text, (paste_x, paste_y), rotated_text) # Use rotated image as mask + else: + text_bbox = draw.textbbox((0, 0), ocr_texts[index], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x + (w - text_width) // 2 # Center horizontally + text_y = y + (h - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) image_text.save(os.path.join(dir_out, f_name+'.png')) if __name__ == "__main__": From 68a71be8bc77567984131dc5e16a733209bf32f2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sat, 13 Sep 2025 22:40:11 +0200 Subject: [PATCH 121/123] Running inference on files in a directory --- train/inference.py | 86 +++++++++++++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/train/inference.py b/train/inference.py index aecd0e6..094c528 100644 --- a/train/inference.py +++ b/train/inference.py @@ -28,8 +28,9 @@ Tool to load model and predict for given image. 
""" class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area): self.image=image + self.dir_in=dir_in self.patches=patches self.save=save self.save_layout=save_layout @@ -223,11 +224,10 @@ class sbb_predict: return added_image, layout_only - def predict(self): - self.start_new_session_and_model() + def predict(self, image_dir): if self.task == 'classification': classes_names = self.config_params_model['classification_classes_name'] - img_1ch = img=cv2.imread(self.image, 0) + img_1ch = img=cv2.imread(image_dir, 0) img_1ch = img_1ch / 255.0 img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) @@ -438,7 +438,7 @@ class sbb_predict: if self.patches: #def textline_contours(img,input_width,input_height,n_classes,model): - img=cv2.imread(self.image) + img=cv2.imread(image_dir) self.img_org = np.copy(img) if img.shape[0] < self.img_height: @@ -529,7 +529,7 @@ class sbb_predict: else: - img=cv2.imread(self.image) + img=cv2.imread(image_dir) self.img_org = np.copy(img) width=self.img_width @@ -557,22 +557,50 @@ class sbb_predict: def run(self): - res=self.predict() - if (self.task == 'classification' or self.task == 'reading_order'): - pass - elif self.task == 'enhancement': - if self.save: - cv2.imwrite(self.save,res) + self.start_new_session_and_model() + if self.image: + res=self.predict(image_dir = self.image) + + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + if self.save: + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + if self.save: + cv2.imwrite(self.save,img_seg_overlayed) + if self.save_layout: + 
cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + else: - img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) - if self.save: - cv2.imwrite(self.save,img_seg_overlayed) - if self.save_layout: - cv2.imwrite(self.save_layout, only_layout) + ls_images = os.listdir(self.dir_in) + for ind_image in ls_images: + f_name = ind_image.split('.')[0] + image_dir = os.path.join(self.dir_in, ind_image) + res=self.predict(image_dir) - if self.ground_truth: - gt_img=cv2.imread(self.ground_truth) - self.IoU(gt_img[:,:,0],res[:,:,0]) + if (self.task == 'classification' or self.task == 'reading_order'): + pass + elif self.task == 'enhancement': + self.save = os.path.join(self.out, f_name+'.png') + cv2.imwrite(self.save,res) + else: + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) + self.save = os.path.join(self.out, f_name+'_overlayed.png') + cv2.imwrite(self.save,img_seg_overlayed) + self.save_layout = os.path.join(self.out, f_name+'_layout.png') + cv2.imwrite(self.save_layout, only_layout) + + if self.ground_truth: + gt_img=cv2.imread(self.ground_truth) + self.IoU(gt_img[:,:,0],res[:,:,0]) + + @click.command() @click.option( @@ -581,6 +609,12 @@ class sbb_predict: help="image filename", type=click.Path(exists=True, dir_okay=False), ) +@click.option( + "--dir_in", + "-di", + help="directory of images", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--out", "-o", @@ -626,15 +660,19 @@ class sbb_predict: "-min", help="min area size of regions considered for reading order detection. 
The default value is zero and means that all text regions are considered for reading order.", ) -def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): +def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): + assert image or dir_in, "Either a single image -i or a dir_in -di is required" with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] if (task != 'classification' and task != 'reading_order'): - if not save: - print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") + if image and not save: + print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) + if dir_in and not out: + print("Error: You used one of segmentation or binarization task with dir_in but not set -out") + sys.exit(1) + x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) x.run() if __name__=="__main__": From 530897c6c2a9455d3c7713257f15351de8732b99 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Fri, 19 Sep 2025 13:20:26 +0200 Subject: [PATCH 122/123] renaming argument names --- train/generate_gt_for_training.py | 34 +++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 91ee2c8..7810cd7 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -157,6 +157,7 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales): def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, 
input_height, input_width, min_area_size, min_area_early): xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] input_height = int(input_height) input_width = int(input_width) min_area = float(min_area_size) @@ -268,14 +269,14 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i @click.option( "--dir_out", - "-do", + "-o", help="directory where plots will be written", type=click.Path(exists=True, file_okay=False), ) @click.option( "--dir_imgs", - "-dimg", + "-di", help="directory where the overlayed plots will be written", ) def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): @@ -283,6 +284,7 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): if dir_xml: xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] else: xml_files_ind = [xml_file] @@ -353,6 +355,12 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): @main.command() +@click.option( + "--xml_file", + "-xml", + help="xml filename", + type=click.Path(exists=True, dir_okay=False), +) @click.option( "--dir_xml", "-dx", @@ -362,18 +370,24 @@ def visualize_reading_order(xml_file, dir_xml, dir_out, dir_imgs): @click.option( "--dir_out", - "-do", + "-o", help="directory where plots will be written", type=click.Path(exists=True, file_okay=False), ) @click.option( "--dir_imgs", - "-dimg", + "-di", help="directory of images where textline segmentation will be overlayed", ) -def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): - xml_files_ind = os.listdir(dir_xml) +def visualize_textline_segmentation(xml_file, dir_xml, dir_out, dir_imgs): + assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" + if dir_xml: + xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] + else: + 
xml_files_ind = [xml_file] + for ind_xml in tqdm(xml_files_ind): indexer = 0 #print(ind_xml) @@ -408,20 +422,21 @@ def visualize_textline_segmentation(dir_xml, dir_out, dir_imgs): @click.option( "--dir_out", - "-do", + "-o", help="directory where plots will be written", type=click.Path(exists=True, file_okay=False), ) @click.option( "--dir_imgs", - "-dimg", + "-di", help="directory of images where textline segmentation will be overlayed", ) def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" if dir_xml: xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] else: xml_files_ind = [xml_file] @@ -466,7 +481,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): @click.option( "--dir_out", - "-do", + "-o", help="directory where plots will be written", type=click.Path(exists=True, file_okay=False), ) @@ -476,6 +491,7 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): assert xml_file or dir_xml, "A single xml file -xml or a dir of xml files -dx is required not both of them" if dir_xml: xml_files_ind = os.listdir(dir_xml) + xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')] else: xml_files_ind = [xml_file] From a65405bead03f386cf3935df4dd58b1985cfcd21 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 22 Sep 2025 15:56:14 +0200 Subject: [PATCH 123/123] tables are visulaized within layout --- train/generate_gt_for_training.py | 2 +- train/gt_gen_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/train/generate_gt_for_training.py b/train/generate_gt_for_training.py index 7810cd7..388fced 100644 --- a/train/generate_gt_for_training.py +++ b/train/generate_gt_for_training.py @@ -458,7 +458,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): co_text, co_graphic, 
co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], img) + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) diff --git a/train/gt_gen_utils.py b/train/gt_gen_utils.py index 753b0f5..38d48ca 100644 --- a/train/gt_gen_utils.py +++ b/train/gt_gen_utils.py @@ -18,7 +18,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") -def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, img): +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img): alpha = 0.5 blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 @@ -30,6 +30,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ col_image = (0, 100, 0) col_sep = (255, 0, 0) col_marginal = (106, 90, 205) + col_table = (0, 90, 205) if len(co_image)>0: cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour @@ -51,6 +52,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ if len(co_marginal)>0: cv2.drawContours(blank_image, co_marginal, -1, col_marginal, thickness=cv2.FILLED) # Fill the contour + + if len(co_table)>0: + cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)