layout_data.utils.np_transforms 源代码

# -*- encoding: utf-8 -*-
"""
Desc      :   Transforms.
"""
# File    :   np_transforms.py
# Time    :   2020/04/06 17:24:54
# Author  :   Zweien
# Contact :   278954153@qq.com

import cv2
import torch
from torchvision import transforms


[文档]class ToTensor: """Transform np.array to torch.tensor Args: add_dim (bool, optional): add first dim. Defaults to True. type_ (torch.dtype, optional): dtype of the tensor. Defaults to tensor.torch.float32. Returns: torch.tensor: tensor """ def __init__(self, add_dim=True, type_=torch.float32): self.add_dim = add_dim self.type = type_ def __call__(self, x): if self.add_dim: return torch.tensor(x, dtype=self.type).unsqueeze(0) return torch.tensor(x, dtype=self.type)
[文档]class Resize: def __init__(self, size): self.size = size def __call__(self, x): return cv2.resize(x, self.size)
[文档]class Lambda(transforms.Lambda): pass
[文档]class Compose(transforms.Compose): pass
[文档]class Normalize(transforms.Normalize): pass