高性能PyTorch是如何炼成的?整理的10条脱坑指南
将 RGB 图像保持在每个通道深度 8 位。可以轻松地在 GPU 上将图像转换为浮点形式或者标准化。 在数据集中用 uint8 或 uint16 数据类型代替 long。 class MySegmentationDataset(Dataset): ... def __getitem__(self, index): image = cv2.imread(self.images[index]) target = cv2.imread(self.masks[index]) # No data normalization and type casting here return torch.from_numpy(image).permute(2,0,1).contiguous(), torch.from_numpy(target).permute(2,0,1).contiguous() class Normalize(nn.Module): # https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.py def __init__(self, mean, std): super().__init__() self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous()) self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous()) def forward(self, input: torch.Tensor) -> torch.Tensor: return (input.to(self.mean.type) - self.mean) * self.std class MySegmentationModel(nn.Module): def __init__(self): self.normalize = Normalize([0.221 * 255], [0.242 * 255]) self.loss = nn.CrossEntropyLoss() def forward(self, image, target): image = self.normalize(image) output = self.backbone(image) if target is not None: loss = self.loss(output, target.long()) return loss return output (编辑:西安站长网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |