import numpy as np
import cv2 as cv

from itertools import repeat
from multiprocessing import Pool
from operator import itemgetter
from skimage import restoration


def contour_pixels(contours):
    """
    Get all pixel positions within a contour.

    The contour points are sorted by their first coordinate and then by their
    second one (x, then y for OpenCV contours). Within each column (equal
    first coordinate), the gap between every pair of consecutive points is
    filled in. All gaps in a column are filled. For shapes that cross a
    column more than once (e.g. a 'U'), the space between the separate parts
    is therefore included as well.

    Parameters
    ----------
    contours : list
        The contours. Each contour must be reshapeable to ``(N, 2)``, such as
        the ``(N, 1, 2)`` point arrays returned by ``cv.findContours``.

    Returns
    -------
    contour_pixels : list
        The pixels within each contour: one list of 2-element arrays per
        contour, ordered by first and then second coordinate. Gap pixels are
        added without repeating the contour points that bound them.

    """

    pixels_per_contour = []
    for contour in contours:
        flat_contour = np.reshape(contour, (len(contour), 2))
        # Two stable sorts give a lexicographic order on (first, second).
        pixel_sorted = sorted(
            sorted(flat_contour, key=itemgetter(1)), key=itemgetter(0))
        all_pixels = []
        for pixel, following in zip(pixel_sorted, pixel_sorted[1:]):
            all_pixels.append(pixel)
            if following[0] == pixel[0]:
                # Fill the open interval between the two points. Both end
                # points are contour points and are appended on their own,
                # so starting at pixel[1] would duplicate `pixel`.
                all_pixels.extend(
                    np.asarray([pixel[0], y])
                    for y in range(pixel[1] + 1, following[1]))
        if pixel_sorted:
            # The pairwise loop above never appends the final point.
            all_pixels.append(pixel_sorted[-1])
        pixels_per_contour.append(all_pixels)

    return pixels_per_contour


def subtract_background(image, radius):
    """
    Subtract the image background in each channel with the rolling ball 
    algorithm.

    For a multi-channel image, the background of every channel is estimated
    in a ``multiprocessing.Pool`` worker and then subtracted from the image.

    Parameters
    ----------
    image : array
        The image, either ``(rows, cols, channels)`` or a single-channel
        ``(rows, cols)`` array.
    radius : int
        The ball radius / kernel size.

    Returns
    -------
    array
        The image - the background, in the image's dtype.

    """

    # Keep the background in the image's own dtype. A fixed uint8 buffer
    # wraps values above 255 for wider integer images and truncates
    # fractional backgrounds of float images to 0.
    background = np.empty_like(image)

    if image.ndim == 2:
        # A single channel gains nothing from a process pool.
        background[...] = rolling_ball((image, radius))
        return image - background

    channels = [image[..., i] for i in range(image.shape[2])]
    with Pool() as pool:
        result = pool.map(rolling_ball, zip(channels, repeat(radius)))
    for i, channel_background in enumerate(result):
        background[..., i] = channel_background

    # NOTE(review): for unsigned integer images this wraps around wherever the
    # background exceeds the image; the rolling-ball background is assumed to
    # stay at or below the image - confirm.
    return image - background


def rolling_ball(args):
    """
    Estimate the background of a single channel with the rolling ball method.

    Takes one packed argument so it can be used directly with
    ``Pool.map``.

    Parameters
    ----------
    args : tuple
        ``(channel, radius)``: the 2-D channel and the ball radius passed to
        ``restoration.rolling_ball``.

    Returns
    -------
    array
        The background estimated by ``restoration.rolling_ball``.

    """

    image_channel, ball_radius = args
    background = restoration.rolling_ball(image_channel, radius=ball_radius)
    return background


def convex_masking(mask):
    """
    Create a coarse mask from a fine mask by reducing the arc length of the 
    contour and creating the convex hull.

    Two passes are made:

    1. Every contour of the mask is simplified with ``cv.approxPolyDP``, using
       a tolerance of 1 % of the closed contour's perimeter, and drawn filled.
    2. The contours of that intermediate mask are replaced by their convex
       hulls and drawn filled again.

    All contours are retrieved (``cv.RETR_LIST``) and drawn filled, so holes
    in the mask end up filled as well.

    Parameters
    ----------
    mask : array
        The fine mask, a single-channel image accepted by ``cv.findContours``.

    Returns
    -------
    mask : array
        The coarse mask: a ``uint8`` array with the input's shape, 1 inside
        the hulls and 0 elsewhere.

    """

    def filled(polygons):
        # Rasterise all polygons as solid regions of value 1 on a blank uint8
        # canvas with the input mask's shape.
        canvas = np.zeros(mask.shape, np.uint8)
        return cv.drawContours(canvas, polygons, -1, (1), -1)

    outlines, _ = cv.findContours(mask, cv.RETR_LIST, cv.CHAIN_APPROX_NONE)
    simplified_mask = filled([
        cv.approxPolyDP(outline, 0.01 * cv.arcLength(outline, True), True)
        for outline in outlines])

    outlines, _ = cv.findContours(
        simplified_mask, cv.RETR_LIST, cv.CHAIN_APPROX_NONE)
    return filled([cv.convexHull(outline) for outline in outlines])

    
def contour_closing(contours, expected_size):
    """
    Clean longer contours by removing strongly convex or concave parts, giving 
    them a 'closed' outline.

    Which contours are processed:

    - Only contours with more than 80 points whose convex hull area is below
      ``1.5 * expected_size`` are processed.
    - Every other contour is returned unchanged.

    How a processed contour is cleaned:

    - It is cut at its convex-hull vertices into pieces.
    - The area of a piece is taken with the piece closed by the chord between
      its first and last point.
    - A piece is dropped when that area is at least ``(chord / 4) ** 2`` and
      also at least 0.5 % of the contour's area.
    - The remaining pieces are joined, in contour order, into the new contour.

    Parameters
    ----------
    contours : list
        The contours, as ``(N, 1, 2)`` point arrays (e.g. from
        ``cv.findContours``).
    expected_size : int
        The expected area of an object.

    Returns
    -------
    tuple
        The closed contours, one per input contour and in the same order.
        If every piece of a contour would be dropped, that contour is kept
        unchanged.

    """

    closed_contours = []
    for contour in contours:
        needs_closing = (
            len(contour) > 80
            and cv.contourArea(cv.convexHull(contour)) < 1.5 * expected_size)
        if not needs_closing:
            closed_contours.append(contour)
            continue

        area = cv.contourArea(contour)
        # Indices of the hull vertices, in increasing contour order.
        hull = np.sort(
            cv.convexHull(contour, returnPoints=False), axis=0)[:, 0]

        # Cut the contour at the hull vertices. Each piece runs from one hull
        # vertex up to (excluding) the next. The last piece wraps past the end
        # of the contour up to and including the first hull vertex, so the
        # points between index 0 and hull[0] are not lost when hull[0] != 0.
        contour_pieces = [
            contour[start:stop] for start, stop in zip(hull[:-1], hull[1:])]
        contour_pieces.append(np.concatenate(
            (contour[hull[-1]:], contour[:hull[0] + 1]), axis=0))

        # Chord length between the first and last point of each piece.
        dist = [
            np.sqrt(
                (pc[0][0][0] - pc[-1][0][0]) ** 2 +
                (pc[0][0][1] - pc[-1][0][1]) ** 2
            )
            for pc in contour_pieces]
        areas = [cv.contourArea(piece) for piece in contour_pieces]
        filtered_pieces = [
            piece for piece, d, a in zip(contour_pieces, dist, areas)
            if not (a >= (d / 4) ** 2 and a >= 0.005 * area)]

        if filtered_pieces:
            closed_contours.append(np.concatenate(filtered_pieces, axis=0))
        else:
            # np.concatenate cannot join zero arrays. When every piece is
            # dropped, nothing meaningful is left, so keep the original.
            closed_contours.append(contour)

    return tuple(closed_contours)