import os
import torch
import numpy as np
from torch.utils.data.dataset import Dataset
from cv2 import resize, cvtColor, COLOR_BGR2RGB, COLOR_RGB2GRAY
from cv2 import imread as cv_imread
class InitializeDataPredict:
    """
    Initialize data for segmentation.

    Scans a directory once and records the names of the files it contains.
    ``items`` and ``image_path`` are public attributes; ``SegDatasetPredict``
    reads both of them directly.
    """

    def __init__(self, image_path):
        """
        Args:
            image_path: Directory holding the images to run prediction on.

        Raises:
            OSError: Propagated from ``os.listdir`` when ``image_path`` does
                not exist or cannot be listed.
        """
        self.image_path = image_path
        # os.listdir returns entries in arbitrary order, so sort for a
        # deterministic dataset order. Sub-directories are skipped because
        # they can never be decoded as images and would only fail later at
        # load time.
        self.items = sorted(
            name
            for name in os.listdir(image_path)
            if os.path.isfile(os.path.join(image_path, name))
        )
class SegDatasetPredict(Dataset):
    """
    Create data sample for each image containing the resized RGB image and a
    single channel equivalent for segmentation.

    Each sample is a dict with the keys ``'file_name'``, ``'image'`` (the
    grayscale version with a trailing channel axis added) and ``'rgb'`` (the
    resized RGB image).
    """

    def __init__(self, data: "InitializeDataPredict", transform=None,
                 image_size=(2048, 1536)):
        """
        Args:
            data: Object exposing ``items`` (list of file names) and
                ``image_path`` (directory containing those files).
            transform: Optional callable applied to the sample dict before
                the ``'rgb'`` entry is attached.
            image_size: Target size handed to ``cv2.resize`` as ``dsize``,
                which OpenCV interprets as ``(width, height)``. Defaults to
                the previously hard-coded ``(2048, 1536)``.
        """
        self.transform = transform
        self.init_data = data
        self.items = self.init_data.items
        self.image_size = tuple(image_size)

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.items)

    def __getitem__(self, index):
        """
        Load, resize and convert the image at ``index``.

        Args:
            index: Integer position, or a tensor that converts to one via
                ``tolist()``.

        Returns:
            dict: ``'file_name'`` and ``'image'`` (possibly modified by
            ``transform``), plus ``'rgb'`` holding the untransformed resized
            RGB image.

        Raises:
            cv2.error: If the file cannot be read as an image.
        """
        if torch.is_tensor(index):
            index = index.tolist()

        file_name = self.items[index]
        path = os.path.join(self.init_data.image_path, file_name)

        image_bgr = cv_imread(path)
        if image_bgr is None:
            # cv2.imread signals failure by returning None; without this
            # check resize() fails with an assertion that does not mention
            # the offending file. The exception type is kept as cv2.error so
            # existing handlers continue to work.
            from cv2 import error as cv_error
            raise cv_error("Could not read image file: {}".format(path))

        image_rgb = cvtColor(resize(image_bgr, self.image_size), COLOR_BGR2RGB)
        # Grayscale conversion drops the channel axis; add it back as the
        # last axis so the result is a single-channel image.
        image = np.expand_dims(cvtColor(image_rgb, COLOR_RGB2GRAY), axis=-1)

        sample = {
            'file_name': file_name,
            'image': image
        }
        if self.transform:
            sample = self.transform(sample)
        # Attached after the transform so the RGB image is returned exactly
        # as resized, untouched by the transform.
        sample.update({'rgb': image_rgb})
        return sample