import os
import pickle
import torch
import numpy as np
import cv2 as cv
from torchvision.transforms import transforms
from torch.utils.data.dataloader import DataLoader
from tqdm.tk import tqdm_tk as tqdm
from sys import exit
from ..model.model import UNet
from ..utils.DataUtils import SegDatasetPredict, InitializeDataPredict
from ..utils.Transforms import Normalize, ToTensor
def segmentation_setup(state_path, image_dir):
    """
    Initialize the device, model and data loader for a segmentation run.

    Parameters
    ----------
    state_path : str
        The path to the saved states of the weighted UNet. The loaded object
        must provide the model weights under the key ``'state_dict'``.
    image_dir : str
        The directory containing the images that shall be segmented.

    Returns
    -------
    device : torch.device
        The initialized torch.device, CUDA if available and CPU otherwise.
    model : torch.nn.DataParallel
        The UNet with loaded weights, moved to ``device``, wrapped in
        ``DataParallel`` and switched to evaluation mode.
    data_loader : torch.utils.data.DataLoader
        The initialized dataloader for the images (batch size 4, no
        shuffling, 4 worker processes).

    Notes
    -----
    Setting ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES`` modifies the
    environment of the running process.
    """
    # Order devices by PCI bus id and restrict CUDA to the listed indices.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # Remap stored tensors onto the selected device. Without map_location a
    # checkpoint saved on a GPU cannot be loaded on a CPU-only machine, which
    # would defeat the CPU fallback chosen above.
    states = torch.load(state_path, map_location=device)

    # NOTE(review): the four 0.2 values are presumably per-stage dropout
    # rates -- confirm against the UNet constructor.
    model = UNet([0.2, 0.2, 0.2, 0.2])
    model.load_state_dict(states['state_dict'])
    model.to(device)
    model = torch.nn.DataParallel(model)
    model.eval()

    init_data = InitializeDataPredict(image_path=image_dir)
    dataset = SegDatasetPredict(
        init_data, transforms.Compose([Normalize(), ToTensor()]))
    # shuffle=False keeps the images in the order the dataset provides them.
    data_loader = DataLoader(
        dataset=dataset, batch_size=4, shuffle=False, num_workers=4,
        prefetch_factor=10)
    return device, model, data_loader
def segmentation(device, model, batch):
    """
    Segment one batch of images and extract the object contours.

    The batch is run through the model without gradient tracking. For every
    image the prediction is turned into a one-hot mask with three channels
    and the external contours of channel 0 (single cells) and channel 1
    (aggregates) are extracted with OpenCV.

    Parameters
    ----------
    device : torch.device
        The device the input tensors are moved to before inference.
    model : torch.model
        The model used to perform the segmentation.
    batch : torch.dataloader batch
        The batch to be segmented. It is read through the keys ``'image'``
        (input tensors), ``'rgb'`` (RGB images) and ``'file_name'``.

    Returns
    -------
    single : dict
        The contours that were segmented as single cells. The keys are the
        image file names.
    aggregate : dict
        The contours that were segmented as aggregated cells. The keys are
        the image file names.
    image_rgb : array
        The RGB images of the batch converted to a numpy array.
    csv_rows : list of str
        One row per image with the layout
        ``<file name>;<single contours>;<aggregate contours>`` where the
        contours are written as nested lists.
    """
    inputs = batch['image'].to(device)
    rgb_images = batch['rgb']
    names = batch['file_name']

    with torch.no_grad():
        prediction = model(inputs)
    prediction = prediction.detach().cpu().numpy()

    single_cnt = {}
    aggregate_cnt = {}
    csv_rows = []
    for name, scores in zip(names, prediction):
        # Pixels whose channel-2 score is below 0.5 keep their two class
        # scores; all other pixels are assigned to the third channel.
        keep = scores[2] < 0.5
        background = (~keep).astype(np.float64)[np.newaxis]
        class_scores = np.where(keep, scores[:2], 0.)
        stacked = np.concatenate((class_scores, background), axis=0)

        # One-hot encode: a channel is set wherever it holds the per-pixel
        # maximum over all three channels.
        one_hot = (stacked == stacked.max(axis=0)).astype(np.uint8)

        single, _ = cv.findContours(
            one_hot[0], cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        aggregate, _ = cv.findContours(
            one_hot[1], cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)

        key = str(name)
        single_cnt[key] = single
        aggregate_cnt[key] = aggregate

        single_points = [contour.tolist() for contour in single]
        aggregate_points = [contour.tolist() for contour in aggregate]
        csv_rows.append(f"{name};{single_points};{aggregate_points}")

    return single_cnt, aggregate_cnt, rgb_images.numpy(), csv_rows
def run_segmentation(
        state_path, source_dir, destination_dir, name, limit, only_seg):
    """
    Run the segmentation and extract all recognized object contours.

    All processed images are annotated with a running frame number and
    written to ``<destination_dir>/<name>.avi``. The contours of every image
    are appended as rows to ``<destination_dir>/<name>.csv``.

    Parameters
    ----------
    state_path : str
        Path to model states.
    source_dir : str
        Path to images.
    destination_dir : str
        Path to saving directory.
    name : str
        Name of the experiment, used as base name of all output files.
    limit : int
        Number of images to segment. ``0`` segments all images.
    only_seg : bool
        If true, the contours are pickled to
        ``<destination_dir>/<name>_contours.pkl`` and the program exits
        after the segmentation instead of returning.

    Returns
    -------
    contours : dict
        The dictionary containing all found object contours of all segmented
        images, stored under the keys ``'single'`` and ``'aggregate'``.
    video_name : str
        Name of the created video from images.

    Notes
    -----
    Images are processed in whole batches, so ``limit`` is rounded up to the
    next multiple of the batch size and capped at the number of available
    batches.

    The video writer is opened with a fixed frame size of 2048x1536 and the
    frame number is drawn at a fixed position near the top right corner.
    NOTE(review): OpenCV presumably drops frames of any other size --
    confirm before using images with a different resolution.
    """
    # Local import keeps this function self-contained.
    from itertools import islice

    device, model, data = segmentation_setup(state_path, source_dir)

    # Translate the image limit into a number of batches (ceiling division),
    # never exceeding what the loader can deliver.
    batch_size = data.batch_size
    available_batches = len(data)
    if limit != 0:
        requested = (int(limit) + batch_size - 1) // batch_size
        n_batches = max(0, min(requested, available_batches))
    else:
        n_batches = available_batches

    contours = {'single': dict(), 'aggregate': dict()}
    base_path = os.path.join(destination_dir, name)
    video_name = base_path + '.avi'
    csv_path = base_path + '.csv'

    # The CSV is opened in append mode; only write the header when the file
    # is new or empty so repeated runs do not insert header rows mid-file.
    write_header = (
        not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0)

    out = cv.VideoWriter(
        video_name, cv.VideoWriter_fourcc(*'XVID'), 20, (2048, 1536))
    frame_number = 0
    try:
        with open(csv_path, 'a', newline='') as csvfile:
            if write_header:
                csvfile.write("image;single contours;aggregate contours\n")
            # islice stops after exactly n_batches without requesting a
            # further batch from the loader.
            batches = islice(data, n_batches)
            for batch in tqdm(
                    batches, desc='Segmenting image batches',
                    total=n_batches):
                single, aggregate, images, csv_rows = segmentation(
                    device, model, batch)
                contours['single'].update(single)
                contours['aggregate'].update(aggregate)
                for img in images:
                    cv.putText(
                        img, str(frame_number), (1900, 30),
                        cv.FONT_HERSHEY_SIMPLEX, 1,
                        color=(255,)*3, thickness=2)
                    out.write(cv.cvtColor(img, cv.COLOR_RGB2BGR))
                    frame_number += 1
                csvfile.writelines(row + "\n" for row in csv_rows)
    finally:
        # Always finalize the video file, also on errors and before exit().
        out.release()

    if only_seg:
        with open(base_path + '_contours.pkl', 'wb') as f:
            pickle.dump(contours, f)
        exit()
    else:
        return contours, video_name