import glob
import os
import random

import albumentations as A
import cv2 as cv
import numpy as np

from tqdm import tqdm


def vignette(image):
    """
    Apply a vignette of random strength to an image.

    A 2-D Gaussian falloff mask is built from the outer product of two 1-D
    Gaussian kernels (one per image axis, sigma = 1.5 * axis length). The
    mask is normalised so that its peak equals a random factor drawn
    uniformly from [0.6, 0.9], then multiplied into the image.

    Parameters
    ----------
    image : array
        The image, either 2-D (rows, cols) or 3-D (rows, cols, channels).
        For 3-D images only the first three channels are darkened; any
        further channel (e.g. alpha) is left untouched.

    Returns
    -------
    output : array
        The vignetted image: a copy with the same shape and dtype as the
        input. The input array is not modified.

    """

    rows, cols = image.shape[:2]
    # getGaussianKernel returns column vectors, so the outer product of the
    # row-axis kernel with the transposed column-axis kernel is (rows, cols).
    x_kernel = cv.getGaussianKernel(cols, 1.5 * cols)
    y_kernel = cv.getGaussianKernel(rows, 1.5 * rows)
    kernel = y_kernel * x_kernel.T
    mask = random.uniform(0.6, 0.9) * kernel / np.amax(kernel)

    output = np.copy(image)
    if output.ndim == 2:
        # Grayscale: the mask already has the image's shape.
        output[:, :] = output * mask
    else:
        # Broadcast the mask over the (up to three) colour channels. Slice
        # assignment casts the float product back to the image dtype exactly
        # as a per-channel assignment would.
        output[:, :, :3] = output[:, :, :3] * mask[:, :, np.newaxis]

    return output


def _read_image(path):
    """Read an image with OpenCV, raising OSError if it cannot be loaded."""
    image = cv.imread(path)
    # cv.imread signals failure by returning None instead of raising.
    if image is None:
        raise OSError(f'Could not read image: {path}')
    return image


def _write_image(path, image):
    """Write an image with OpenCV, raising OSError if the write fails."""
    # cv.imwrite reports failure through its boolean return value.
    if not cv.imwrite(path, image):
        raise OSError(f'Could not write image: {path}')


def main(n_augmentations=5):
    """
    Apply albumentations to the training images.

    Every PNG matched by the image glob is read together with the mask of
    the same file name from each mask directory. The composed transform is
    applied ``n_augmentations`` times, and each result is written as
    ``<stem>_<i>.png``: the image next to the original image, each mask into
    the directory it was read from.

    Parameters
    ----------
    n_augmentations : int, optional
        Number of augmented copies to generate per image. The default is 5.

    Raises
    ------
    OSError
        If an image or mask cannot be read, or an output cannot be written.

    """

    # Set directories
    image_dir = r'PATH\TO\IMAGES\*.png'
    mask1_dir = r'PATH\TO\FINE\MASKS'
    mask2_dir = r'PATH\TO\COARSE\MASKS'
    mask3_dir = r'PATH\TO\FINE\MASKS\SINGLE'
    mask4_dir = r'PATH\TO\COARSE\MASKS\SINGLE'
    mask5_dir = r'PATH\TO\FINE\MASKS\AGGREGATE'
    mask6_dir = r'PATH\TO\COARSE\MASKS\AGGREGATE'

    mask_dirs = [
        mask1_dir, mask2_dir, mask3_dir, mask4_dir, mask5_dir, mask6_dir]

    # Define random seeds for consistency
    random.seed(0)
    np.random.seed(0)

    # Define transformations
    transform = A.Compose([
        A.HorizontalFlip(p=0.6),
        A.RandomBrightnessContrast(p=0.5),
        A.MotionBlur(p=0.3),
        A.RandomGamma(p=0.5),
        A.CLAHE(p=0.5),
        A.RGBShift(p=0.3)

    ])

    # Process
    # The file list is fixed before anything is written, so images generated
    # during this run are not picked up again.
    # NOTE: outputs are written next to the inputs as PNG files, so a second
    # run will also match the previously generated '<stem>_<i>.png' files.
    filenames = glob.glob(image_dir)

    for filename in tqdm(filenames):
        basename = os.path.basename(filename)
        stem = os.path.splitext(basename)[0]
        image_out_dir = os.path.dirname(filename)

        image = _read_image(filename)
        masks = [
            _read_image(os.path.join(mask_d, basename))
            for mask_d in mask_dirs]

        for i in range(n_augmentations):
            augmented = transform(image=image, masks=masks)
            out_name = stem + f'_{i}.png'
            _write_image(
                os.path.join(image_out_dir, out_name), augmented['image'])
            # Augmented masks are indexed in the same order as mask_dirs.
            for n, mask_d in enumerate(mask_dirs):
                _write_image(
                    os.path.join(mask_d, out_name), augmented['masks'][n])
    
    
# Run the augmentation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()