# encoding: utf-8
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from .fpn_head import FPNDecoder
from .resnet import resnet50
from .model_init import weights_init
from ..scheduler import WarmupMultiStepLR
from pytorch_lightning import LightningModule
from layout_data.data.layout import LayoutDataset
import layout_data.utils.np_transforms as transforms
class FPNModel(LightningModule):
    """ResNet-50 backbone + FPN decoder that regresses heat fields from layouts.

    Batches are ``(layout, heat)`` pairs; the model maps ``layout`` to a
    predicted heat field and is trained with an L1 loss against ``heat``.
    All configuration is read from ``hparams`` (see
    :meth:`add_model_specific_args` for the recognized options).
    """

    def __init__(self, hparams):
        """Build the network and loss; datasets are created in prepare_data().

        Args:
            hparams: Namespace-like object exposing the options declared in
                :meth:`add_model_specific_args` as attributes.
        """
        super().__init__()
        # NOTE(review): direct assignment matches the dict-returning
        # pytorch_lightning API used throughout this class; newer releases
        # deprecate/forbid it in favour of ``self.save_hyperparameters(hparams)``
        # — confirm against the pinned Lightning version.
        self.hparams = hparams
        self._build_model()
        self.criterion = nn.L1Loss()
        # Populated by prepare_data().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _build_model(self):
        """Create the ResNet-50 backbone and the FPN decoder head."""
        self.backbone = resnet50()
        # Presumably the backbone's feature-map channels, deepest stage first —
        # confirm against FPNDecoder's expected ordering.
        self.head = FPNDecoder(encoder_channels=[2048, 1024, 512, 256])
        # Only the backbone goes through weights_init; the head keeps whatever
        # initialization FPNDecoder gives it.
        self.backbone.apply(weights_init)

    def forward(self, x):
        """Predict heat fields for a batch of layouts.

        The decoder output is returned as-is (no output activation), and is
        compared directly against the normalized target heat field.
        """
        features = self.backbone(x)
        return self.head(features)

    def __dataloader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader configured from hparams.

        Args:
            dataset: Dataset to iterate over.
            shuffle: Reshuffle the data every epoch; only the training loader
                enables it so validation/test order stays deterministic.
        """
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
        )

    def _make_transform(self, mean, std):
        """Return the resize -> to-tensor -> normalize pipeline for one field."""
        size: int = self.hparams.input_size
        return transforms.Compose(
            [
                transforms.Resize(size=(size, size)),
                transforms.ToTensor(),
                transforms.Normalize(torch.tensor([mean]), torch.tensor([std])),
            ]
        )

    def prepare_data(self):
        """Build the train/val/test datasets and store them on the module.

        The training split of ``LayoutDataset`` is divided into train and
        validation subsets; ``hparams.train_size`` is the fraction kept for
        training. The test split is used unchanged. Layouts and heat fields
        are normalized with their own mean/std hyper-parameters.
        """
        transform_layout = self._make_transform(
            self.hparams.mean_layout, self.hparams.std_layout
        )
        transform_heat = self._make_transform(
            self.hparams.mean_heat, self.hparams.std_heat
        )
        train_dataset = LayoutDataset(
            self.hparams.data_root,
            train=True,
            transform=transform_layout,
            target_transform=transform_heat,
        )
        test_dataset = LayoutDataset(
            self.hparams.data_root,
            train=False,
            transform=transform_layout,
            target_transform=transform_heat,
        )
        # train/val split. random_split draws from torch's global RNG, so the
        # split is only reproducible if the caller seeds torch beforehand.
        train_size = int(self.hparams.train_size * len(train_dataset))
        lengths = [train_size, len(train_dataset) - train_size]
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, lengths
        )
        print(
            f"Prepared dataset, train:{len(train_dataset)}, "
            f"val:{len(val_dataset)}, test:{len(test_dataset)}"
        )
        # assign to use in dataloaders
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self.__dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation subset (fixed order)."""
        return self.__dataloader(self.val_dataset)

    def test_dataloader(self):
        """Loader over the test split (fixed order)."""
        return self.__dataloader(self.test_dataset)

    def _log_images(self, tag, images):
        """Log the first four images of ``images`` as one normalized grid.

        Does nothing when no logger is attached (``self.logger`` is None).
        Images are detached so the grid is not part of the autograd graph.
        """
        if self.logger is None:
            return
        grid = torchvision.utils.make_grid(images[:4, ...].detach(), normalize=True)
        # NOTE(review): add_image assumes a TensorBoard-style experiment object.
        self.logger.experiment.add_image(tag, grid, self.global_step)

    def training_step(self, batch, batch_idx):
        """Compute the L1 loss for one training batch.

        On the first batch of each epoch the predicted heat fields are logged;
        the ground-truth fields are additionally logged once, at global step 0.
        """
        layout, heat = batch
        heat_pred = self(layout)
        loss = self.criterion(heat_pred, heat)
        if batch_idx == 0:
            self._log_images("train_pred_heat_field", heat_pred)
            if self.global_step == 0:
                self._log_images("train_heat_field", heat)
        return {"loss": loss, "log": {"training_loss": loss}}

    def validation_step(self, batch, batch_idx):
        """Compute the L1 loss for one validation batch.

        Predictions are logged for the first batch of each validation run;
        the ground-truth heat fields and input layouts are logged once, at
        global step 0.
        """
        layout, heat = batch
        heat_pred = self(layout)
        loss = self.criterion(heat_pred, heat)
        if batch_idx == 0:
            # Every batch of a validation run shares the same global_step, so
            # logging each one would only write redundant images.
            self._log_images("val_pred_heat_field", heat_pred)
            if self.global_step == 0:
                self._log_images("val_heat_field", heat)
                self._log_images("val_layout_field", layout)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the mean."""
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": val_loss_mean, "log": {"val_loss": val_loss_mean}}

    def test_step(self, batch, batch_idx):
        """Compute the L1 loss for one test batch."""
        layout, heat = batch
        heat_pred = self(layout)
        loss = self.criterion(heat_pred, heat)
        return {"test_loss": loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses; report to progress bar and log."""
        test_loss_mean = torch.stack([x["test_loss"] for x in outputs]).mean()
        tqdm_dict = {"test_loss_mean": test_loss_mean.item()}
        log = {"test_loss": test_loss_mean.item()}
        return {"progress_bar": tqdm_dict, "log": log}

    def configure_optimizers(self):
        """Create the optimizer selected by ``hparams.optimizer_name``.

        Supports ``"adam"`` and ``"sgd"`` (case-insensitive), both using
        ``hparams.lr`` as the learning rate.

        Raises:
            ValueError: If ``optimizer_name`` is not a supported name.
        """
        optimizers = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
        name = str(self.hparams.optimizer_name).lower()
        if name not in optimizers:
            raise ValueError(
                f"Unsupported optimizer_name: {self.hparams.optimizer_name!r}; "
                f"expected one of {sorted(optimizers)}"
            )
        return optimizers[name](self.parameters(), lr=self.hparams.lr)

    @staticmethod
    def add_model_specific_args(parser):  # pragma: no-cover
        """Register this model's command-line options on ``parser``.

        Parameters you define here will be available to your model
        through ``self.hparams``.

        Args:
            parser: An argparse-style parser exposing ``add_argument``.

        Returns:
            The same parser, to allow chaining.
        """
        # dataset args
        parser.add_argument("--data_root", type=str, default="d:/work/dataset")
        parser.add_argument(
            "--train_size",
            default=0.8,
            type=float,
            help="train_size in train_test_split",
        )
        # network params
        # NOTE(review): drop_prob is not read anywhere in this class.
        parser.add_argument("--drop_prob", default=0.2, type=float)
        parser.add_argument("--input_size", default=200, type=int)
        # normalization statistics used by prepare_data()
        parser.add_argument("--mean_layout", default=0, type=float)
        parser.add_argument("--std_layout", default=1, type=float)
        parser.add_argument("--mean_heat", default=0, type=float)
        parser.add_argument("--std_heat", default=1, type=float)
        # training params (opt)
        parser.add_argument("--optimizer_name", default="adam", type=str)
        parser.add_argument("--lr", default=0.01, type=float)
        parser.add_argument("--batch_size", default=16, type=int)
        parser.add_argument(
            "--num_workers", default=2, type=int, help="num_workers in DatasetLoader"
        )
        return parser