import os
import sys
import traceback

import numpy as np
import cv2 as cv

import utils 

from diplib import MeasurementTool as dpMT


class TrainingImageGenerator:
    """
    Create fine and coarse label masks from an image using blurring,
    contrasting, background subtraction and thresholding.

    ``run()`` writes the resized input image and the generated masks as PNG
    files into a ``training_data`` directory tree located below
    ``saving_directory`` (or the current working directory).
    """

    # A contour inside the expected size band is rated as a single object
    # only if its area is below this fraction of the area of its minimum
    # enclosing circle (see _size_rating).
    CIRCLE_AREA_FRACTION = 0.6
    # Minimum grey-value range (MaxVal - MinVal, measured on the grey CLAHE
    # image) a leftover outer contour needs to be kept (see _intensity_filter).
    MIN_INTENSITY_RANGE = 180
    # Default (width, height) every saved image is resized to by run().
    OUTPUT_SIZE = (2048, 1536)
    # Sub-directories of 'training_data' that run() creates and writes to.
    _SUBFOLDERS = (
        'images',
        'fine_mask_single',
        'coarse_mask_single',
        'fine_mask_aggregate',
        'coarse_mask_aggregate',
        'fine_masks',
        'coarse_masks',
    )

    def __init__(
            self, file_path, ballradius=5, blurring_kernel=(3, 3),
            erosion_kernel_size=3, contour_size_conformity=0.8,
            expected_size=700, saving_directory=None):
        """
        Load the image and store the processing parameters.

        Parameters
        ----------
        file_path : str
            Path to image.
        ballradius : int, optional
            The rolling ball radius for background subtraction.
            The default is 5.
        blurring_kernel : tuple, optional
            Integer tuple for the kernel size. The default is (3,3).
        erosion_kernel_size : int, optional
            Erosion kernel dimension. The default is 3. Stored, but not used
            by the methods of this class.
        contour_size_conformity : float, optional
            Acceptable contour size range as a fraction of the expectation.
            The default is 0.8.
        expected_size : int, optional
            The expected object area. The default is 700.
        saving_directory : str, optional
            The saving directory. The default is None. This equates to the
            current working directory.

        Returns
        -------
        None.

        """
        # User input
        self.file = file_path
        # cv.imread returns None if the file cannot be read; this is reported
        # by _contour_extraction.
        self.image = cv.imread(file_path)
        self.ballradius = ballradius
        self.blurring_kernel = blurring_kernel
        self.erosion_kernel_size = erosion_kernel_size
        self.contour_size_conformity = contour_size_conformity
        self.expected_size = expected_size
        self.save_dir = saving_directory

        # Filled by methods
        self.filename = None
        self.contours = None
        self.hierarchy = None
        self.contours_overlay = None
        self.fine_mask_objects = None
        self.coarse_mask_objects = None
        self.fine_mask_aggregates = None
        self.coarse_mask_aggregates = None
        self.image_clahe = None

    def _size_rating(self, contours):
        """
        Sort contours by area relative to ``expected_size``.

        With ``deviation = contour_size_conformity * expected_size``, the
        expected band is the open interval
        ``(expected_size - deviation, expected_size + deviation)``.

        Parameters
        ----------
        contours : list
            List of contours.

        Returns
        -------
        contours_in_range : list
            Contours whose area lies inside the band and is smaller than
            ``CIRCLE_AREA_FRACTION`` times the area of their minimum
            enclosing circle.
        contours_over_range : list
            Contours whose area is above the band, plus contours inside the
            band that fail the enclosing-circle test.

        Contours whose area is at or below the lower bound, or exactly equal
        to the upper bound, are in neither list.

        NOTE(review): in-band contours that fail the enclosing-circle test
        end up with the over-range (aggregate) contours. Confirm this is
        intended.

        """
        deviation = self.contour_size_conformity * self.expected_size
        lower = self.expected_size - deviation
        upper = self.expected_size + deviation

        contours_in_range = []
        contours_over_range = []
        for contour in contours:
            area = cv.contourArea(contour)
            in_band = lower < area < upper
            _, radius = cv.minEnclosingCircle(contour)
            circle_area = np.pi * radius ** 2
            if in_band and area < self.CIRCLE_AREA_FRACTION * circle_area:
                contours_in_range.append(contour)
            elif in_band or area > upper:
                contours_over_range.append(contour)
        return contours_in_range, contours_over_range

    def _file_name(self):
        """
        Return the input file's base name without its (last) extension.

        Example: ``data/my.image.png`` gives ``my.image``.
        """
        return os.path.splitext(os.path.basename(self.file))[0]

    @staticmethod
    def _hierarchy_rows(hierarchy):
        """Return the per-contour [next, prev, child, parent] rows, or ()."""
        # cv.findContours returns None as hierarchy when nothing was found.
        return hierarchy[0] if hierarchy is not None else ()

    def _contour_mask(self, contours, value=1, thickness=-1):
        """Return a blank uint8 image with ``contours`` drawn in ``value``."""
        # thickness=-1 fills the contours; a positive value draws outlines.
        return cv.drawContours(
            np.zeros(self.image.shape[:2], np.uint8), contours, -1, value,
            thickness=thickness)

    def _contour_extraction(self):
        """
        Extract contours using cv2's extraction function. The image is first
        preprocessed to remove background and enhance objects.

        Also stores the CLAHE-enhanced image in ``self.image_clahe``.

        Returns
        -------
        contours : list
            Contours found.
        hierarchy : list
            Hierarchy of the contours.

        Raises
        ------
        FileNotFoundError
            If the input image could not be read.

        """
        if self.image is None:
            raise FileNotFoundError(f'Could not read image file: {self.file}')

        clahe = cv.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
        # CLAHE works on single-channel 8-bit images, so equalise each colour
        # channel separately and stack the results back together.
        self.image_clahe = cv.merge(
            [clahe.apply(channel) for channel in cv.split(self.image)])
        image_contrasted = cv.GaussianBlur(
            cv.convertScaleAbs(self.image_clahe, alpha=1.5, beta=0), (5, 5), 3)
        image_blurred = cv.blur(image_contrasted, self.blurring_kernel)
        image_bg_subtracted = utils.subtract_background(
            image_blurred, self.ballradius)
        image_gblurred = cv.GaussianBlur(
            image_bg_subtracted, self.blurring_kernel, 0)
        image_gray = cv.cvtColor(image_gblurred, cv.COLOR_BGR2GRAY)
        _, image_thresholded = cv.threshold(
            image_gray, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
        contours, hierarchy = cv.findContours(
            image_thresholded, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
        return contours, hierarchy

    def _closed_inner_contours(self, contours, hierarchy):
        """Return merged and closed outlines of the second-level contours."""
        rows = self._hierarchy_rows(hierarchy)
        # Second level: the contour has a parent, and that parent itself has
        # no parent.
        inner = [
            contours[i] for i, row in enumerate(rows)
            if row[3] >= 0 and rows[row[3]][3] == -1]
        # Drawing with a 2 px line merges proximal contours; only the
        # outermost of the re-detected outlines are kept.
        merged, merged_hierarchy = cv.findContours(
            self._contour_mask(inner, value=255, thickness=2),
            cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
        outer = [
            merged[i]
            for i, row in enumerate(self._hierarchy_rows(merged_hierarchy))
            if row[3] == -1]
        return utils.contour_closing(outer, self.expected_size)

    def _residual_outer_contours(self, contours, hierarchy, rated_contours):
        """Return regions of top-level convex hulls not in rated_contours."""
        rows = self._hierarchy_rows(hierarchy)
        hulls = [
            cv.convexHull(contours[i])
            for i, row in enumerate(rows) if row[3] == -1]
        # Redetect to merge overlapping hulls into single external outlines.
        outer_contours, _ = cv.findContours(
            self._contour_mask(hulls), cv.RETR_EXTERNAL,
            cv.CHAIN_APPROX_NONE)
        filled = self._contour_mask(outer_contours)
        outline = self._contour_mask(outer_contours, thickness=1)
        rated = self._contour_mask(rated_contours)
        # Signed arithmetic: in uint8, 0 - 1 wraps to 255, which would mark
        # pixels lying only inside rated contours as residual.
        residual = (
            (filled.astype(np.int16) - rated + outline) >= 1).astype(np.uint8)
        found, found_hierarchy = cv.findContours(
            residual, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
        # Keep top-level contours (no parent) that have no children.
        return [
            found[i]
            for i, row in enumerate(self._hierarchy_rows(found_hierarchy))
            if row[3] == -1 and row[2] == -1]

    def _intensity_filter(self, contours):
        """
        Keep contours with enough contrast and area.

        Each contour is labelled (label = index + 1) and measured with diplib
        on the grey version of ``self.image_clahe``. A contour is kept when
        ``MaxVal - MinVal >= MIN_INTENSITY_RANGE`` and
        ``Size >= expected_size``. A contour whose measurement cannot be read
        is reported and dropped.
        """
        if not contours:
            return []

        labels = np.zeros(self.image.shape[:2], np.uint32)
        # NOTE(review): assumes utils.contour_pixels returns one list of
        # (x, y) pixels per input contour, in the same order.
        for label, pixels in enumerate(
                utils.contour_pixels(contours), start=1):
            for pixel in pixels:
                labels[pixel[1], pixel[0]] = label

        grey = cv.cvtColor(self.image_clahe, cv.COLOR_BGR2GRAY)
        measurements = dpMT.Measure(
            labels, grey, ["MinVal", "MaxVal", "Size"])

        kept = []
        for label, contour in enumerate(contours, start=1):
            try:
                # Frequent RuntimeErrors occur for some images for unclear
                # reasons in diplib; skip such contours instead of aborting
                # or misaligning the remaining measurements.
                features = measurements[label]
                minimum = features["MinVal"][0]
                maximum = features["MaxVal"][0]
                size = features["Size"][0]
            except Exception:
                traceback.print_exc()
                continue
            if (maximum - minimum >= self.MIN_INTENSITY_RANGE
                    and size >= self.expected_size):
                kept.append(contour)
        return kept

    def _create_mask(self, contours, hierarchy):
        """
        Create masks from the found contours after sorting and
        post-processing.

        1. Merged, closed second-level contours are rated by size.
        2. Regions of the top-level convex hulls not covered by step 1 are
           filtered by contrast and area, then rated by size.
        3. Filled masks, their convex versions and a colour overlay are built.

        The results are stored in ``fine_mask_objects``,
        ``fine_mask_aggregates``, ``coarse_mask_objects``,
        ``coarse_mask_aggregates`` and ``contours_overlay``. They are assigned
        together at the end, so an exception part-way through leaves the
        previous values untouched.

        Parameters
        ----------
        contours : list
            The contours.
        hierarchy : list
            The contour hierarchy.

        Returns
        -------
        None.

        """
        closed_contours = self._closed_inner_contours(contours, hierarchy)
        object_contours_1, aggregate_contours_1 = self._size_rating(
            closed_contours)

        residual_contours = self._residual_outer_contours(
            contours, hierarchy, object_contours_1 + aggregate_contours_1)
        object_contours_2, aggregate_contours_2 = self._size_rating(
            self._intensity_filter(residual_contours))

        object_contour_list = object_contours_1 + object_contours_2
        aggregate_contour_list = aggregate_contours_1 + aggregate_contours_2

        fine_objects = self._contour_mask(object_contour_list)
        fine_aggregates = self._contour_mask(aggregate_contour_list)
        coarse_objects = utils.convex_masking(fine_objects)
        coarse_aggregates = utils.convex_masking(fine_aggregates)
        # Objects outlined in red, aggregates in blue (BGR colour order).
        overlay = cv.drawContours(
            cv.drawContours(
                self.image.copy(), object_contour_list, -1, (0, 0, 255)),
            aggregate_contour_list, -1, (255, 0, 0))

        self.fine_mask_objects = fine_objects
        self.fine_mask_aggregates = fine_aggregates
        self.coarse_mask_objects = coarse_objects
        self.coarse_mask_aggregates = coarse_aggregates
        self.contours_overlay = overlay

    def run(self, output_size=None):
        """
        Main function of the class. Create or check directories, extract
        contours, create masks and save them.

        Files are written as ``training_data/<subfolder>/<file name>.png``
        below ``saving_directory`` (or the current working directory). Label
        masks are resized with nearest-neighbour interpolation so they keep
        their discrete values; the image uses bilinear interpolation.

        Parameters
        ----------
        output_size : tuple of int, optional
            ``(width, height)`` of the saved images. The default (None) uses
            ``OUTPUT_SIZE``, i.e. (2048, 1536).

        Returns
        -------
        None.

        Raises
        ------
        FileNotFoundError
            If the input image could not be read.

        """
        root = os.getcwd() if self.save_dir is None else self.save_dir
        training_root = os.path.join(root, 'training_data')
        folders = {
            name: os.path.join(training_root, name)
            for name in self._SUBFOLDERS}

        created = False
        for directory in (training_root, *folders.values()):
            if not os.path.exists(directory):
                os.makedirs(directory)
                created = True
        if created:
            print('Created directory...')

        self.filename = self._file_name()
        if self.fine_mask_objects is None:
            self.contours, self.hierarchy = self._contour_extraction()
            try:
                self._create_mask(self.contours, self.hierarchy)
            except ValueError:
                # Best effort: skip this image, but report why.
                traceback.print_exc()
                return

        size = self.OUTPUT_SIZE if output_size is None else tuple(output_size)
        outputs = (
            ('fine_mask_single', self.fine_mask_objects, cv.INTER_NEAREST),
            ('coarse_mask_single', self.coarse_mask_objects,
             cv.INTER_NEAREST),
            ('fine_mask_aggregate', self.fine_mask_aggregates,
             cv.INTER_NEAREST),
            ('coarse_mask_aggregate', self.coarse_mask_aggregates,
             cv.INTER_NEAREST),
            ('images', self.image, cv.INTER_LINEAR),
            ('fine_masks',
             self.fine_mask_objects + self.fine_mask_aggregates,
             cv.INTER_NEAREST),
            ('coarse_masks',
             self.coarse_mask_objects + self.coarse_mask_aggregates,
             cv.INTER_NEAREST),
        )
        for folder, data, interpolation in outputs:
            path = os.path.join(folders[folder], self.filename + '.png')
            resized = cv.resize(data, size, interpolation=interpolation)
            if not cv.imwrite(path, resized):
                print(f'Could not write {path}')

        
                
        