Source code for DiatomTrack.model.model

import torch

import torch.nn as nn
import torch.nn.functional as F


[docs]class UNet(nn.Module): def __init__(self, weights, n_class=3): super(UNet, self).__init__() self.weight1 = nn.Parameter(torch.tensor(weights[0])) self.weight2 = nn.Parameter(torch.tensor(weights[1])) self.weight3 = nn.Parameter(torch.tensor(weights[2])) self.weight4 = nn.Parameter(torch.tensor(weights[3])) self.level1_conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1) self.level1_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) self.level1_maxpool = nn.MaxPool2d(kernel_size=2) self.level1_bat1 = nn.BatchNorm2d(num_features=64) self.level1_bat2 = nn.BatchNorm2d(num_features=64) self.level2_conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) self.level2_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) self.level2_maxpool = nn.MaxPool2d(kernel_size=2) self.level2_bat1 = nn.BatchNorm2d(num_features=128) self.level2_bat2 = nn.BatchNorm2d(num_features=128) self.level3_conv1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) self.level3_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) self.level3_maxpool = nn.MaxPool2d(kernel_size=2) self.level3_bat1 = nn.BatchNorm2d(num_features=256) self.level3_bat2 = nn.BatchNorm2d(num_features=256) self.level4_conv1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) self.level4_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) self.level4_maxpool = nn.MaxPool2d(kernel_size=2) self.level4_bat1 = nn.BatchNorm2d(num_features=512) self.level4_bat2 = nn.BatchNorm2d(num_features=512) self.level5_conv1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1) self.level5_conv2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1) self.level5_bat1 = nn.BatchNorm2d(num_features=1024) self.level5_bat2 = nn.BatchNorm2d(num_features=1024) self.level5_upsample = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1) self.level4_conv3 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1) self.level4_conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) self.level4_bat3 = nn.BatchNorm2d(num_features=512) self.level4_bat4 = nn.BatchNorm2d(num_features=512) self.level4_upsample = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1) self.level3_conv3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1) self.level3_conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) self.level3_bat3 = nn.BatchNorm2d(num_features=256) self.level3_bat4 = nn.BatchNorm2d(num_features=256) self.level3_upsample = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) self.level2_conv3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) self.level2_conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) self.level2_bat3 = nn.BatchNorm2d(num_features=128) self.level2_bat4 = nn.BatchNorm2d(num_features=128) self.level2_upsample = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) self.level1_conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1) self.level1_conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) self.level1_conv5 = nn.Conv2d(64, n_class, kernel_size=3, stride=1, padding=1) self.level1_bat3 = nn.BatchNorm2d(num_features=64) self.level1_bat4 = nn.BatchNorm2d(num_features=64)
[docs] def forward(self, x): self.feature1 = F.relu(self.level1_bat2( self.level1_conv2(F.relu(self.level1_bat1(self.level1_conv1(x)))))) self.feature2 = F.relu(self.level2_bat2( self.level2_conv2(F.relu( self.level2_bat1(self.level2_conv1(self.level1_maxpool(self.feature1))))))) self.feature3 = F.relu(self.level3_bat2( self.level3_conv2(F.relu( self.level3_bat1(self.level3_conv1(self.level2_maxpool(self.feature2))))))) self.feature4 = F.relu(self.level4_bat2( self.level4_conv2(F.relu( self.level4_bat1(self.level4_conv1(self.level3_maxpool(self.feature3))))))) self.feature5 = self.level5_bat2( self.level5_conv2( self.level5_bat1(self.level5_conv1(self.level4_maxpool(self.feature4))))) self.feature4_up = torch.cat( (self.weight1 * self.level5_upsample(self.feature5), (1 - self.weight1) * self.feature4), dim=1) self.feature4_2 = F.relu( self.level4_bat4(self.level4_conv4( F.relu(self.level4_bat3(self.level4_conv3(self.feature4_up)))))) self.feature3_up = torch.cat( (self.weight2 * self.level4_upsample(self.feature4_2), (1 - self.weight2) * self.feature3), dim=1) self.feature3_2 = F.relu( self.level3_bat4(self.level3_conv4( F.relu(self.level3_bat3(self.level3_conv3(self.feature3_up)))))) self.feature2_up = torch.cat( (self.weight3 * self.level3_upsample(self.feature3_2), (1 - self.weight3) * self.feature2), dim=1) self.feature2_2 = F.relu( self.level2_bat4(self.level2_conv4( F.relu(self.level2_bat3(self.level2_conv3(self.feature2_up)))))) self.feature1_up = torch.cat( (self.weight4 * self.level2_upsample(self.feature2_2), (1 - self.weight4) * self.feature1), dim=1) self.feature1_2 = F.relu( self.level1_bat4(self.level1_conv4( F.relu(self.level1_bat3(self.level1_conv3(self.feature1_up)))))) # additional softmax self.out = F.softmax(self.level1_conv5(self.feature1_2), dim=1) return self.out