layout_data.models.fpn.model_init 源代码

# encoding: utf-8
import torch


def weights_init(m):
    """Initialize the parameters of a single module in place.

    Weight initialization function meant to be called by a model (for
    example the CRNN model), e.g. through ``model.apply(weights_init)``.
    The kind of layer is detected from the module's class name:

    * name contains ``"Conv"``: Kaiming-normal weights
      (``mode="fan_out"``, ``nonlinearity="relu"``); bias is left untouched.
    * name contains ``"BatchNorm"`` but not ``"WithFixedBatchNorm"``:
      weight set to 1 and bias set to 0.
    * name contains ``"Linear"``: Xavier-normal weights, bias set to 0.
    * name contains ``"LSTM"`` (this also covers ``"LSTMCell"``):
      Xavier-uniform input-hidden weights, orthogonal hidden-hidden
      weights, biases set to 0.

    Modules that match by name but do not own a tensor ``weight``/``bias``
    (e.g. ``BatchNorm2d(affine=False)``, ``Conv2d(bias=False)`` or a
    container class such as ``ConvBlock``) are skipped for the missing
    parameter instead of raising.

    :param m: the ``nn.Module`` to initialize
    :return: None
    """
    class_name = m.__class__.__name__

    # Matching is done on the class name, so a module may match a branch
    # without actually owning the parameter; look the tensors up defensively.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if "Conv" in class_name:
        # Initialize convolution weights.
        if has_weight:
            torch.nn.init.kaiming_normal_(
                weight, mode="fan_out", nonlinearity="relu"
            )
    elif "BatchNorm" in class_name and "WithFixedBatchNorm" not in class_name:
        # Batch-norm layers must not be initialized with kaiming_normal.
        if has_weight:
            torch.nn.init.constant_(weight, 1)
        if has_bias:
            torch.nn.init.constant_(bias, 0)
    elif "Linear" in class_name:
        if has_weight:
            torch.nn.init.xavier_normal_(weight.data)
        if has_bias:
            bias.data.fill_(0)
    elif "LSTM" in class_name:  # "LSTMCell" contains "LSTM" as well
        for name, param in m.named_parameters():
            if "weight_ih" in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                torch.nn.init.orthogonal_(param.data)
            elif "bias" in name:
                param.data.fill_(0)
def weights_init_without_kaiming(m):
    """Initialize the parameters of a single module in place, without Kaiming.

    Weight initialization function meant to be called by a model (for
    example the CRNN model), e.g. through
    ``model.apply(weights_init_without_kaiming)``. The kind of layer is
    detected from the module's class name:

    * name contains ``"Conv"``: Xavier-normal weights; bias is left
      untouched.
    * name contains ``"BatchNorm"`` but not ``"WithFixedBatchNorm"``:
      weight set to 1 and bias set to 0.
    * name contains ``"Linear"``: Xavier-normal weights, bias set to 0.
    * name contains ``"LSTM"`` (this also covers ``"LSTMCell"``):
      Xavier-uniform input-hidden weights, orthogonal hidden-hidden
      weights, biases set to 0.

    Modules that match by name but do not own a tensor ``weight``/``bias``
    (e.g. ``BatchNorm2d(affine=False)``, ``Conv2d(bias=False)`` or a
    container class such as ``ConvBlock``) are skipped for the missing
    parameter instead of raising.

    :param m: the ``nn.Module`` to initialize
    :return: None
    """
    class_name = m.__class__.__name__

    # Matching is done on the class name, so a module may match a branch
    # without actually owning the parameter; look the tensors up defensively.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if "Conv" in class_name:
        # Initialize convolution weights.
        if has_weight:
            torch.nn.init.xavier_normal_(weight)
    elif "BatchNorm" in class_name and "WithFixedBatchNorm" not in class_name:
        # Batch-norm layers must not be initialized with kaiming_normal.
        if has_weight:
            torch.nn.init.constant_(weight, 1)
        if has_bias:
            torch.nn.init.constant_(bias, 0)
    elif "Linear" in class_name:
        if has_weight:
            torch.nn.init.xavier_normal_(weight.data)
        if has_bias:
            bias.data.fill_(0)
    elif "LSTM" in class_name:  # "LSTMCell" contains "LSTM" as well
        for name, param in m.named_parameters():
            if "weight_ih" in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                torch.nn.init.orthogonal_(param.data)
            elif "bias" in name:
                param.data.fill_(0)