Source code for DiatomTrack.utils.DataUtils

import os
import torch

import numpy as np

from torch.utils.data.dataset import Dataset
from cv2 import resize, cvtColor, COLOR_BGR2RGB, COLOR_RGB2GRAY

from cv2 import imread as cv_imread

    
    
class InitializeDataPredict:
    """Collect the entry names of an image directory for prediction.

    The listing is sorted because ``os.listdir`` returns entries in
    arbitrary order; sorting gives stable, reproducible indices.

    Args:
        image_path: Directory containing the images to segment.

    Attributes:
        image_path: The directory that was passed in, kept so that full
            file paths can be rebuilt with ``os.path.join(image_path, name)``.
        items: Sorted list of every entry name in ``image_path``. No
            filtering is applied, so subdirectories and non-image files
            (if any) are included too.
    """

    def __init__(self, image_path):
        self.image_path = image_path
        # sorted() already returns a new list, so no in-place sort + index
        # loop copy is needed.
        self.items = sorted(os.listdir(image_path))
class SegDatasetPredict(Dataset):
    """Dataset yielding resized images ready for segmentation prediction.

    Each sample is a dict holding the file name, a single-channel
    (grayscale) copy of the resized image, and the resized RGB image.

    Args:
        data: Provides the file names (``data.items``) and the directory
            they live in (``data.image_path``).
        transform: Optional callable applied to the
            ``{'file_name': ..., 'image': ...}`` dict. It runs before the
            ``'rgb'`` entry is added, and its result must support ``update``.
        size: Target size passed to ``cv2.resize`` as ``(width, height)``.
            Defaults to ``(2048, 1536)``, the previously hard-coded value.
    """

    def __init__(
        self,
        data: InitializeDataPredict,
        transform=None,
        size=(2048, 1536),
    ):
        self.transform = transform
        self.init_data = data
        self.items = self.init_data.items
        self.size = size

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        """Load, resize and convert the image at ``index``.

        Returns:
            dict with keys ``'file_name'`` and ``'image'`` (possibly
            transformed), plus ``'rgb'``, the resized RGB image.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        # Tensor indices are converted to plain Python values before use.
        if torch.is_tensor(index):
            index = index.tolist()

        file_name = self.items[index]
        path = os.path.join(self.init_data.image_path, file_name)

        # cv2.imread signals failure by returning None rather than raising;
        # report the offending path instead of letting cv2.resize fail with
        # an opaque assertion error.
        image_bgr = cv_imread(path)
        if image_bgr is None:
            raise OSError(f"Could not read image file: {path}")

        # cv2.resize expects dsize as (width, height).
        image_rgb = cvtColor(resize(image_bgr, self.size), COLOR_BGR2RGB)
        # RGB2GRAY yields a 2-D array; add a trailing channel axis (H, W, 1).
        image = np.expand_dims(cvtColor(image_rgb, COLOR_RGB2GRAY), axis=-1)

        sample = {'file_name': file_name, 'image': image}
        if self.transform is not None:
            sample = self.transform(sample)
        # Added after the transform so the RGB image is returned untouched.
        sample.update({'rgb': image_rgb})
        return sample