Questions

  • Why use transform?
  • How do you use transform?

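Question 1

To speed up computation: transforms turn images into tensors, and tensors can be processed on the GPU, which makes computation faster.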

Question 2

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :transform_1.py
@IDE :PyCharm
@Author :咋
@Date :2023/6/29 18:16
"""
from torchvision import transforms

from PIL import Image

image_path = "data\\train\\ants_image\\5650366_e22b7e1065.jpg"

image = Image.open(image_path)
transform_tool = transforms.ToTensor()  # create a transform tool
image_tensor = transform_tool(image)
print(image_tensor)


# output
# tensor([[[0.3804, 0.3804, 0.3843, ..., 0.3412, 0.3373, 0.3333],
# [0.3765, 0.3804, 0.3843, ..., 0.3529, 0.3490, 0.3451],
# [0.3804, 0.3804, 0.3843, ..., 0.3725, 0.3686, 0.3647],
# ...,
# [0.6078, 0.6078, 0.6118, ..., 0.4627, 0.4627, 0.4627],
# [0.5882, 0.5922, 0.5922, ..., 0.4588, 0.4588, 0.4588],
# [0.5804, 0.5804, 0.5843, ..., 0.4549, 0.4549, 0.4549]],
#
# [[0.4667, 0.4667, 0.4706, ..., 0.4039, 0.4000, 0.3961],
# [0.4706, 0.4667, 0.4706, ..., 0.3922, 0.3882, 0.3843],
# [0.4745, 0.4745, 0.4784, ..., 0.3804, 0.3765, 0.3725],
# ...,
# [0.5961, 0.5961, 0.6000, ..., 0.4588, 0.4588, 0.4588],
# [0.5882, 0.5922, 0.5922, ..., 0.4549, 0.4549, 0.4549],
# [0.5804, 0.5804, 0.5804, ..., 0.4510, 0.4510, 0.4510]],
#
# [[0.4157, 0.4157, 0.4196, ..., 0.3608, 0.3569, 0.3529],
# [0.4196, 0.4157, 0.4196, ..., 0.3569, 0.3529, 0.3490],
# [0.4235, 0.4235, 0.4235, ..., 0.3608, 0.3569, 0.3529],
# ...,
# [0.5608, 0.5608, 0.5647, ..., 0.4392, 0.4392, 0.4392],
# [0.5412, 0.5529, 0.5608, ..., 0.4353, 0.4353, 0.4353],
# [0.5333, 0.5412, 0.5608, ..., 0.4314, 0.4314, 0.4314]]])

transforms is really just a toolbox: we take the tool classes it provides and use them to build our own concrete tools.
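The toolbox idea follows a common Python pattern: a class is configured once, and the resulting object is then called like a function, just as `transforms.ToTensor()` is created first and `transform_tool(image)` is called afterwards. The sketch below imitates that pattern in plain Python. `AddOffset` is a made-up class used only for illustration and is not part of torchvision.

```python
class AddOffset:
    """Hypothetical transform-style tool: configure once, then call on data."""

    def __init__(self, offset):
        # Settings are stored when the tool is built.
        self.offset = offset

    def __call__(self, values):
        # Calling the object applies the tool to the input.
        return [v + self.offset for v in values]


add_ten = AddOffset(10)        # build a concrete tool from the class
result = add_ten([1, 2, 3])    # use the tool by calling it like a function
print(result)                  # [11, 12, 13]
```

Every transform in the sections that follow is used in the same two steps: build the tool, then call it on the data.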

Common transforms

1. ToTensor

import torch
from torchvision import transforms
from PIL import Image

image_path = "test.jpg"
image = Image.open(image_path)


# totensor
tran_tensor = transforms.ToTensor()
image_tensor = tran_tensor(image)
print(type(image))
print(type(image_tensor))

Converts a numpy array read by OpenCV, or an Image object read by PIL, into a tensor object.
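For a typical 8-bit image, ToTensor does more than change the type. It rescales pixel values from 0..255 to 0.0..1.0 and reorders the layout from height x width x channel to channel x height x width. The plain-Python sketch below reproduces only that arithmetic on nested lists so it can be checked by hand. `to_chw_unit_range` is a hypothetical helper, not a torchvision function.

```python
def to_chw_unit_range(pixels_hwc):
    """Rescale 0..255 values to 0..1 and reorder H x W x C into C x H x W."""
    height = len(pixels_hwc)
    width = len(pixels_hwc[0])
    channels = len(pixels_hwc[0][0])
    return [
        [[pixels_hwc[y][x][c] / 255 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1 x 2 RGB image: one pure red pixel followed by one pure green pixel.
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],
]

image_chw = to_chw_unit_range(image_hwc)
print(image_chw)  # [[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]
```

This is why the tensor printed earlier has three outer blocks (one per channel) and only values between 0 and 1.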

2. Normalize

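# Normalize: per channel, output = (input - mean) / std
tran_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
image_norm = tran_norm(image_tensor)

print(image_tensor)  # values before normalization
print(image_norm)    # values after normalization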

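Normalization brings the data to a similar scale and distribution so that the model trains better. Reducing the differences between inputs improves the speed and stability of convergence.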
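Normalize computes `(value - mean) / std` for each channel. With a mean of 0.5 and a std of 0.5, as in the code above, values in the range 0..1 are mapped to -1..1. The sketch below works through that arithmetic in plain Python for a few sample values. The `normalize` function is written here for illustration only.

```python
def normalize(value, mean, std):
    """The per-element formula applied by normalization."""
    return (value - mean) / std


mean, std = 0.5, 0.5
samples = [0.0, 0.25, 0.5, 1.0]

normalized = [normalize(v, mean, std) for v in samples]
print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
```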

3. Compose

# Compose: apply the transforms in list order
tran_com = transforms.Compose([
    tran_tensor,
    tran_norm,
])

image_compose = tran_com(image)
print(image)
print(image_compose)

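Chains several transform modules together. Pay attention to the input and output of each module, because the output of one step becomes the input of the next. In PyCharm you can Ctrl+click a function to jump into its source and check its input and output types.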
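Conceptually, a composed transform just loops over its steps and feeds each result into the next one, which is why the order and the input/output types matter. The plain-Python sketch below shows that idea. `Chain`, `to_unit_range`, and `center` are hypothetical names for illustration, standing in for Compose, ToTensor, and Normalize respectively.

```python
class Chain:
    """Hypothetical stand-in for a composed transform."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, data):
        # The output of each step is the input of the next one.
        for step in self.steps:
            data = step(data)
        return data


def to_unit_range(pixels):
    """Expects values in 0..255, returns values in 0..1."""
    return [p / 255 for p in pixels]


def center(pixels):
    """Expects values in 0..1, returns values in -1..1."""
    return [(p - 0.5) / 0.5 for p in pixels]


pipeline = Chain([to_unit_range, center])
result = pipeline([0, 255])
print(result)  # [-1.0, 1.0]
```

Swapping the two steps would still run, but `center` would receive 0..255 values it was not designed for, which is the kind of mismatch to watch for when ordering real transforms.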

4. RandomCrop

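# RandomCrop
from torch.utils.tensorboard import SummaryWriter

tran_crop = transforms.RandomCrop(512)  # crop a random 512x512 patch
write = SummaryWriter("log_1")

for i in range(10):
    image_crop = tran_crop(image_tensor)  # a new random position on each call
    write.add_image("change", image_crop, i)

# view with: tensorboard --logdir=log_1 --port=6007
write.close()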

Randomly crops the image; the results can be viewed in TensorBoard.
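The idea behind a random crop is to pick a random top-left corner such that the crop window still fits inside the image, then cut that window out. The sketch below does this in plain Python on a small grid of numbers. `random_crop` is a hypothetical helper for illustration and is not how torchvision is implemented internally.

```python
import random


def random_crop(rows, crop_h, crop_w, rng):
    """Cut a crop_h x crop_w window at a random position inside rows."""
    height, width = len(rows), len(rows[0])
    # Choose the top-left corner so the window stays inside the grid.
    top = rng.randint(0, height - crop_h)
    left = rng.randint(0, width - crop_w)
    return [row[left:left + crop_w] for row in rows[top:top + crop_h]]


rng = random.Random(0)  # seeded only so that runs are repeatable

# A 4 x 6 grid whose values encode their position: 10 * row + column.
grid = [[10 * y + x for x in range(6)] for y in range(4)]

crop = random_crop(grid, 2, 3, rng)
print(crop)
```

Each call picks a new position, which is why the ten images logged in the loop above differ from one another.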

5. Resize

# resize
tran_resize = transforms.Resize((512,512))
image_resize = tran_resize(image_tensor)
print(image_tensor.shape)
print(image_resize.shape)
# output
# torch.Size([3, 512, 768])
# torch.Size([3, 512, 512])