Source code for DiatomTrack.core.segmentation

import itertools
import math
import os
import pickle
from sys import exit

import torch
import numpy as np
import cv2 as cv

from torchvision.transforms import transforms
from torch.utils.data.dataloader import DataLoader

from tqdm.tk import tqdm_tk as tqdm

from ..model.model import UNet
from ..utils.DataUtils import SegDatasetPredict, InitializeDataPredict
from ..utils.Transforms import Normalize, ToTensor



def segmentation_setup(state_path, image_dir):
    """
    Take the path for the states of the model and the directory of the
    images to be segmented. Initialize the segmentation device, model and
    data.

    Parameters
    ----------
    state_path : str
        The path to the saved states of the weighted UNet. The file must
        contain a ``'state_dict'`` entry.
    image_dir : str
        The directory of the images that shall be segmented.

    Returns
    -------
    device : torch.device
        The initialized torch.device, CUDA if available, otherwise CPU.
    model : torch.nn.DataParallel
        The UNet with the loaded weights, moved to ``device``, wrapped in
        ``torch.nn.DataParallel`` and switched to evaluation mode.
    data_loader : torch.utils.data.DataLoader
        The initialized dataloader for the images (batch size 4, no
        shuffling, 4 worker processes).
    """
    # Order GPUs by PCI bus id and expose devices 0 and 1. These are set
    # before the torch.cuda call below, presumably so they apply when CUDA
    # is initialised in this process.
    # NOTE(review): the value contains a space; CUDA documents a plain
    # comma-separated list ("0,1") -- confirm both GPUs are actually visible.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # map_location remaps stored tensors onto the selected device. Without
    # it, a checkpoint saved from a GPU cannot be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    states = torch.load(state_path, map_location=device)

    # NOTE(review): the four 0.2 values are presumably per-level dropout
    # rates of the UNet -- confirm against ..model.model.UNet.
    model = UNet([0.2, 0.2, 0.2, 0.2])
    model.load_state_dict(states['state_dict'])
    model.to(device)
    model = torch.nn.DataParallel(model)
    model.eval()

    init_data = InitializeDataPredict(image_path=image_dir)
    dataset = SegDatasetPredict(
        init_data, transforms.Compose([Normalize(), ToTensor()]))
    data_loader = DataLoader(
        dataset=dataset, batch_size=4, shuffle=False, num_workers=4,
        prefetch_factor=10)

    return device, model, data_loader
def _label_mask(mask):
    """Convert one network output of shape (C, H, W), C >= 3, into a uint8 label mask.

    Channel 2 is compared against 0.5. Pixels where it is >= 0.5 are assigned
    to a third "background" channel. Every other pixel goes to whichever of
    channels 0 and 1 is largest. Pixels tied for the maximum get a 1 in every
    tied channel.
    """
    filter_mask = mask[2] < 0.5
    background_mask = np.expand_dims(
        np.invert(filter_mask).astype(np.float64), axis=0)
    # Zero out channels 0/1 wherever the background channel wins.
    modified_mask = np.where(filter_mask, mask[:2], 0.)
    modified_mask = np.append(modified_mask, background_mask, axis=0)
    return np.where(
        modified_mask == np.amax(modified_mask, axis=0), 1, 0).astype(
            np.uint8)


def _external_contours(channel):
    """Return the outer contours of the non-zero regions of a 2-D uint8 mask."""
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] selects the contours
    # in both versions.
    return cv.findContours(
        channel, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[-2]


def segmentation(device, model, batch):
    """
    Process a batch of images on the defined device with the given model.

    Return two dictionaries with contour lists for single and aggregated
    cells, the RGB images and one CSV row per image. The keys of the
    dictionaries are the image file names.

    Parameters
    ----------
    device : torch.device
        The initialized torch.device, CUDA if available.
    model : torch.nn.Module
        The model loaded to the device to perform the segmentation.
    batch : dict
        A dataloader batch with the keys ``'image'`` (input tensor),
        ``'rgb'`` (RGB image tensor) and ``'file_name'`` (file names).

    Returns
    -------
    single : dict
        The contours that were segmented as single cells, keyed by
        ``str(file_name)``.
    aggregate : dict
        The contours that were segmented as aggregated cells, keyed by
        ``str(file_name)``.
    image_rgb : numpy.ndarray
        ``batch['rgb']`` converted to a NumPy array.
    csv_rows : list of str
        One ``"<file name>;<single contours>;<aggregate contours>"`` line
        per image, with each contour array serialized via ``tolist()``.
    """
    image, image_rgb = batch['image'].to(device), batch['rgb']
    file_names = batch['file_name']

    # Inside no_grad the output does not require grad, so no detach needed.
    with torch.no_grad():
        out = model(image)
    batch_mask = out.cpu().numpy()

    single_cnt = {}
    aggregate_cnt = {}
    csv_rows = []
    for i, mask in enumerate(batch_mask):
        file_name = file_names[i]
        labels = _label_mask(mask)
        single = _external_contours(labels[0])
        aggregate = _external_contours(labels[1])

        single_cnt[str(file_name)] = single
        aggregate_cnt[str(file_name)] = aggregate
        csv_rows.append(
            f"{file_name};{[arr.tolist() for arr in single]};"
            f"{[arr.tolist() for arr in aggregate]}")

    return single_cnt, aggregate_cnt, image_rgb.numpy(), csv_rows
def run_segmentation(
        state_path, source_dir, destination_dir, name, limit, only_seg):
    """
    Run the segmentation and extract all recognized object contours.

    Save the contours to ``<destination_dir>/<name>.csv`` and write all
    images, annotated with their frame index, to
    ``<destination_dir>/<name>.avi``.

    Parameters
    ----------
    state_path : str
        Path to model states.
    source_dir : str
        Path to images.
    destination_dir : str
        Path to saving directory.
    name : str
        Name of the experiment.
    limit : int
        Number of images to segment; ``0`` segments all images.
    only_seg : bool
        If true, pickle the contours to
        ``<destination_dir>/<name>_contours.pkl`` and exit the program
        (``sys.exit``) instead of returning.

    Returns
    -------
    contours : dict
        ``{'single': {...}, 'aggregate': {...}}``, each mapping image file
        names to the contours found in that image.
    video_name : str
        Path of the created video file.

    Notes
    -----
    A non-zero limit is divided by the batch size and rounded up, so the
    actual number of images is the limit rounded up to the next multiple of
    the batch size (capped at the number of available images).

    The CSV file is opened in append mode. Its header line is only written
    when the file is new or empty.
    """
    device, model, data = segmentation_setup(state_path, source_dir)

    # Convert the image limit into a whole number of batches, rounding up.
    # (The previous `l > limit` check processed one extra batch.)
    n_batches = len(data)
    if limit:
        n_batches = max(0, min(n_batches, math.ceil(limit / data.batch_size)))

    base_path = os.path.join(destination_dir, name)
    video_name = base_path + '.avi'
    csv_path = base_path + '.csv'
    contours = {'single': dict(), 'aggregate': dict()}

    # Appending a header on every run would put a second header line in the
    # middle of an existing file.
    write_header = (
        not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0)

    out = cv.VideoWriter(
        video_name, cv.VideoWriter_fourcc(*'XVID'), 20, (2048, 1536))
    try:
        with open(csv_path, 'a', newline='') as csvfile:
            if write_header:
                csvfile.write("image;single contours;aggregate contours\n")

            frame_idx = 0
            batches = itertools.islice(data, n_batches)
            for batch in tqdm(
                    batches, desc='Segmenting image batches',
                    total=n_batches):
                single, aggregate, images, csv_rows = segmentation(
                    device, model, batch)
                contours['single'].update(single)
                contours['aggregate'].update(aggregate)

                for img in images:
                    # Stamp the running frame index in the top-right corner.
                    cv.putText(
                        img, str(frame_idx), (1900, 30),
                        cv.FONT_HERSHEY_SIMPLEX, 1, color=(255,)*3,
                        thickness=2)
                    out.write(cv.cvtColor(img, cv.COLOR_RGB2BGR))
                    frame_idx += 1

                csvfile.writelines(row + "\n" for row in csv_rows)
    finally:
        # Release explicitly so the video file is finalized, including when
        # an error occurs or the process exits below.
        out.release()

    if only_seg:
        with open(base_path + '_contours.pkl', 'wb') as f:
            pickle.dump(contours, f)
        exit()
    return contours, video_name