【Background】
Sabour S., Frosst N., Hinton G. E. Dynamic Routing Between Capsules. NIPS 2017.
【Environment】
- CUDA: 9.2.148
- Torch: 1.2.0
- OS: Ubuntu 16.04
- HW: Nvidia Tesla P100 / Nvidia GTX 1080Ti / Nvidia RTX 2080Ti
【An Introduction to CapsNet】
See the companion article: 【CapsNet】Dynamic Routing Between Capsules
【Starting from the Model】
Conv1
This step is a regular convolution: 256 filters of size 9×9 with stride 1 are applied to the 28×28 input, giving a 20×20×256 output. According to the paper, the main purpose of this step is to perform an initial round of local feature detection on the image pixels. Let us see how the Conv1 dimensions are obtained: with no padding, the output spatial size is (28 − 9)/1 + 1 = 20. (But why not use capsules from the very beginning? Because a capsule is meant to represent an "instance" of some entity, it is better suited to representing high-level instances. Using capsules directly to extract low-level features from the image works poorly, while CNNs excel at extracting low-level features, so starting with a CNN layer is reasonable.)
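As a quick sanity check, the Conv1 output shape can be verified with a dummy tensor (a standalone sketch; the hyperparameters match `self.conv` in the CapsuleNet code below):

```python
import torch
import torch.nn as nn

# Conv1: 256 filters, 9x9 kernel, stride 1, no padding
conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9)
x = torch.randn(1, 1, 28, 28)  # a dummy MNIST-sized image
print(conv1(x).shape)          # torch.Size([1, 256, 20, 20]) -> 20x20x256
```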
The full CapsuleNet implementation is as follows:
```python
import torch
import torch.nn as nn


class CapsuleNet(nn.Module):
    """Basic implementation of a capsule network."""

    def __init__(self):
        super(CapsuleNet, self).__init__()
        # Conv2d layer
        self.conv = nn.Conv2d(1, 256, 9)
        self.relu = nn.ReLU(inplace=True)

        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units=32,
                                        in_channels=256,
                                        out_channels=8,
                                        kernel_size=9,
                                        stride=2)

        # Digit capsule
        self.digit_caps = DigitCaps(in_dim=8,
                                    in_caps=32 * 6 * 6,
                                    out_caps=10,
                                    out_dim=16,
                                    num_routing=3)

        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid())

    def forward(self, x):
        out = self.relu(self.conv(x))
        out = self.primary_caps(out)
        out = self.digit_caps(out)

        # Shape of logits: (batch_size, out_capsules)
        logits = torch.norm(out, dim=-1)
        # One-hot prediction from the capsule lengths
        # (`device` is the global set in the __main__ block below)
        pred = torch.eye(10).to(device).index_select(dim=0, index=torch.argmax(logits, dim=1))

        # Reconstruction: mask out all but the predicted capsule, then decode
        batch_size = out.shape[0]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size, -1))

        return logits, reconstruction
```
PrimaryCaps:
The capsules first appear in the Conv2 layer. If we followed the usual CNN recipe and applied 32 filters of size 9×9×256 with stride 2, we would only get a 6×6×32 output, since the output spatial size is ⌊(20 − 9)/2⌋ + 1 = 6.
However, from Hinton's paper the Conv2 layer has dimensions 6×6×8×32. Where does this 8 come from, and what does it mean? My understanding is that the 32 filters of size 9×9×256 with stride 2 are applied 8 times (equivalently, a single convolution with 32 × 8 = 256 output channels, which is exactly what the code below does), and:
- In a CNN, a layer of dimension 6×6×1×32 has 6×6×32 elements, each of which is a scalar.
- In a capsule layer, a layer of dimension 6×6×8×32 has 6×6×32 elements, each of which is a 1×8 vector, i.e. a capsule.
The output of the Conv2 layer is called the Primary Capsule layer in the paper, PrimaryCaps for short; it mainly stores vectors of low-level features.
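To see this equivalence concretely, here is a standalone sketch (the variable names are illustrative) showing that a single convolution with 256 output channels reshapes into 6×6×32 = 1152 capsules of dimension 8:

```python
import torch
import torch.nn as nn

# One convolution with 32 * 8 = 256 output channels...
conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2)
feat = conv2(torch.randn(1, 256, 20, 20))    # (1, 256, 6, 6)

# ...reshaped so that the last axis is the 8-dim capsule vector
capsules = feat.contiguous().view(1, -1, 8)  # (1, 1152, 8)
print(feat.shape, capsules.shape)
```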
The PrimaryCaps module is defined as follows:
```python
class PrimaryCaps(nn.Module):
    """Primary capsule layer."""

    def __init__(self, num_conv_units, in_channels, out_channels, kernel_size, stride):
        super(PrimaryCaps, self).__init__()

        # Each conv unit stands for a single capsule.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels * num_conv_units,
                              kernel_size=kernel_size,
                              stride=stride)
        self.out_channels = out_channels

    def forward(self, x):
        # Shape of x: (batch_size, in_channels, height, width)
        # Shape of out: out_capsules * (batch_size, out_channels, height, width)
        out = self.conv(x)
        # Flatten out: (batch_size, out_capsules * height * width, out_channels)
        batch_size = out.shape[0]
        return squash(out.contiguous().view(batch_size, -1, self.out_channels), dim=-1)
```
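Both PrimaryCaps and DigitCaps call a `squash` helper that is not shown in the snippets above. A minimal sketch following the squashing non-linearity of the paper, v = (‖s‖² / (1 + ‖s‖²)) · (s / ‖s‖), with a small epsilon of my choosing added for numerical stability:

```python
def squash(x, dim=-1):
    """Squashing non-linearity: shrinks short vectors to ~0 and long vectors to length ~1."""
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / (squared_norm.sqrt() + 1e-8)
```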
DigitCaps:
The next layer stores vectors of high-level features, which in this example are the digits. The output of this fully connected layer is called the Digit Capsule layer in the paper, DigitCaps for short. PrimaryCaps and DigitCaps are fully connected, but instead of connecting scalars to scalars as in a traditional neural network, they connect vectors to vectors.
PrimaryCaps contains 6×6×32 elements, each a 1×8 vector, while DigitCaps has 10 elements (one per digit), each a 1×16 vector. To fully connect the 1×8 vectors to the 1×16 vectors, we need 6×6×32×10 weight matrices of size 8×16, one W_ij per input-output capsule pair (this matches the W parameter of shape (1, 10, 1152, 16, 8) in the code below).
PrimaryCaps therefore has 6×6×32 = 1152 capsules and DigitCaps has 10, so i = 1, 2, ..., 1152 and j = 0, 1, ..., 9. The dynamic routing algorithm described in Section 2.4 of the companion article is then iterated 3 times to compute the coupling coefficients c_ij and output the 10 vectors v_j.
The DigitCaps code is as follows:
```python
class DigitCaps(nn.Module):
    """Digit capsule layer."""

    def __init__(self, in_dim, in_caps, out_caps, out_dim, num_routing):
        """
        Initialize the layer.

        Args:
            in_dim: Dimensionality of each input capsule vector.
            in_caps: Number of input capsules.
            out_caps: Number of capsules in this layer.
            out_dim: Dimensionality of the output capsule vectors.
            num_routing: Number of iterations of the routing algorithm.
        """
        super(DigitCaps, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        self.device = device
        self.W = nn.Parameter(0.01 * torch.randn(1, out_caps, in_caps, out_dim, in_dim),
                              requires_grad=True)

    def forward(self, x):
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        # W @ x =
        # (1, out_caps, in_caps, out_dim, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, out_caps, in_caps, out_dim, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, out_caps, in_caps, out_dim)
        u_hat = u_hat.squeeze(-1)
        # Detach u_hat during routing iterations to prevent gradients from flowing there
        temp_u_hat = u_hat.detach()

        b = torch.zeros(batch_size, self.out_caps, self.in_caps, 1).to(self.device)
        for route_iter in range(self.num_routing - 1):
            # (batch_size, out_caps, in_caps, 1) -> softmax along out_caps
            c = b.softmax(dim=1)

            # Element-wise multiplication:
            # (batch_size, out_caps, in_caps, 1) * (batch_size, out_caps, in_caps, out_dim) ->
            # (batch_size, out_caps, in_caps, out_dim), summed across in_caps ->
            # (batch_size, out_caps, out_dim)
            s = (c * temp_u_hat).sum(dim=2)
            # Apply the "squashing" non-linearity along out_dim
            v = squash(s)
            # Dot-product agreement between the current output v_j and the prediction u_hat_j|i:
            # (batch_size, out_caps, in_caps, out_dim) @ (batch_size, out_caps, out_dim, 1)
            # -> (batch_size, out_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # The last iteration is done on the original u_hat, without the routing weight update
        c = b.softmax(dim=1)
        s = (c * u_hat).sum(dim=2)
        # Apply the "squashing" non-linearity along out_dim
        v = squash(s)

        return v
```
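As a quick usage check (`device` must be defined first, since DigitCaps reads that global), feeding a random batch of 1152 8-dim capsules should yield 10 16-dim capsules per example:

```python
device = torch.device('cpu')  # DigitCaps reads this global in __init__
layer = DigitCaps(in_dim=8, in_caps=32 * 6 * 6, out_caps=10, out_dim=16, num_routing=3)
v = layer(torch.randn(2, 1152, 8))
print(v.shape)  # torch.Size([2, 10, 16])
```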
【Loss Function】
The subscript k indexes the classes, and T_k = 1 if and only if a digit of class k is present. The total loss is the sum of the per-digit-capsule losses. The margin loss given in the paper is

$$
L_k = T_k \max(0,\, m^+ - \lVert \mathbf{v}_k \rVert)^2 + \lambda \,(1 - T_k)\, \max(0,\, \lVert \mathbf{v}_k \rVert - m^-)^2
$$

with m⁺ = 0.9, m⁻ = 0.1 and λ = 0.5:

- If class k is present, the loss is zero once ‖v_k‖ ≥ m⁺ = 0.9, and grows quadratically as ‖v_k‖ falls below it.
- If class k is absent, the loss is zero once ‖v_k‖ ≤ m⁻ = 0.1, and grows quadratically as ‖v_k‖ rises above it, down-weighted by λ = 0.5.

With λ = 0.5, penalizing false negatives (failing to activate a present class) is weighted about twice as heavily as penalizing false positives (activating an absent class). On top of the margin loss, the reconstruction (MSE) loss is added with a scale factor of 0.0005.
The loss function is defined as follows:
```python
class CapsuleLoss(nn.Module):
    """Combine margin loss & reconstruction loss of capsule network."""

    def __init__(self, upper_bound=0.9, lower_bound=0.1, lmda=0.5):
        super(CapsuleLoss, self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-4
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        # Shape of left / right / labels: (batch_size, num_classes)
        left = (self.upper - logits).relu() ** 2   # margin term for present classes
        right = (logits - self.lower).relu() ** 2  # margin term for absent classes
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)

        # Reconstruction loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape), images)

        # Combine two losses
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss
```
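A quick usage sketch with random tensors whose shapes match the network above:

```python
criterion = CapsuleLoss()
images = torch.randn(4, 1, 28, 28)
labels = torch.eye(10).index_select(dim=0, index=torch.randint(0, 10, (4,)))  # one-hot
logits = torch.rand(4, 10)            # capsule lengths lie in [0, 1]
reconstructions = torch.rand(4, 784)  # decoder output after Sigmoid
print(criterion(images, labels, logits, reconstructions))
```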
【Training Procedure】
Without further ado, here is the code:
```python
from torch.optim import Adam, lr_scheduler
from tqdm import tqdm  # progress bar


def train(model, train_loader, test_loader, args, device):
    criterion = CapsuleLoss()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.weight_delay)

    model.train()
    for epoch in range(args.epochs):
        correct, total, total_loss = 0, 0, 0.
        for x, y in tqdm(train_loader):
            optimizer.zero_grad()
            x = x.to(device)
            # One-hot encode the labels
            y = torch.eye(10).index_select(dim=0, index=y).to(device)
            y_pred, x_recon = model(x)
            loss = criterion(x, y, y_pred, x_recon)

            correct += torch.sum(torch.argmax(y_pred, dim=1) == torch.argmax(y, dim=1)).item()
            total += len(y)
            accuracy = correct / total
            total_loss += loss.item()  # .item() so the computation graph is not kept alive

            loss.backward()
            optimizer.step()
        scheduler.step()

        # The evaluate function is defined below
        val_loss, val_accuracy = evaluate(model=model, test_loader=test_loader, device=device)
        print('Epoch: %d, train_accuracy: %.2f, val_loss: %.2f, val_accuracy: %.2f'
              % (epoch + 1, accuracy * 100, val_loss * 100, val_accuracy * 100))
```
【Some Auxiliary Utilities】
```python
import torch
from torchvision import datasets, transforms


def load_mnist(path='./datas', download=True, batch_size=128, shift_pixels=2):
    kwargs = {'num_workers': 4, 'pin_memory': True}
    transform = transforms.Compose([
        # Shift by up to shift_pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=shift_pixels),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root=path, download=download, train=True, transform=transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root=path, download=download, train=False, transform=transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    model.eval()
    correct, total = 0, 0
    for images, labels in test_loader:
        images = images.to(device)
        # Categorical (one-hot) encoding
        labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
        logits, reconstructions = model(images)
        pred_labels = torch.argmax(logits, dim=1)
        correct += torch.sum(pred_labels == torch.argmax(labels, dim=1)).item()
        total += len(labels)
    # Returns (error rate, accuracy)
    return 1 - correct / total, correct / total


if __name__ == '__main__':
    args, option = getParams()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CapsuleNet()
    model = model.to(device)
    print(model)
    train_loader, test_loader = load_mnist(batch_size=args.batch_size)
    train(model, train_loader, test_loader, args, device)
```
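The getParams function is not shown in the snippets above. Below is a minimal argparse-based sketch covering only the attributes the code actually reads (args.lr, args.weight_delay, args.epochs, args.batch_size); the default values are my assumptions, and the second return value, unused above, is returned as a plain dict here:

```python
import argparse


def getParams():
    # Hypothetical stand-in for the original getParams; defaults are assumptions.
    parser = argparse.ArgumentParser(description='CapsNet on MNIST')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_delay', type=float, default=0.96,
                        help='gamma of the exponential LR scheduler')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    args = parser.parse_args()
    return args, vars(args)
```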
【References】
- Sabour S., Frosst N., Hinton G. E. Dynamic Routing Between Capsules. NIPS 2017.
【Appendix: Code on GitHub】
Link: Click Me
Environment on Docker Hub: Click Me