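Questions
- Why use transform?
- How do you use transform?
Question 1
To speed things up: transforms turn images into tensors, and tensors can be computed on the GPU, which makes the computation faster.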
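As a minimal sketch of this point, assuming PyTorch is installed and using a random tensor in place of a real image (the variable names here are illustrative, not from the original script), a tensor can be moved to the GPU when one is available and stays on the CPU otherwise:

```python
import torch

# Pick the GPU if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for an image tensor (channels, height, width); not a real image.
image_tensor = torch.rand(3, 224, 224)

# .to(device) returns a copy on the target device (or the same tensor if already there).
image_on_device = image_tensor.to(device)
print(image_on_device.device)
```

A PIL image or NumPy array cannot be moved this way, which is one reason the conversion to a tensor comes first.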
Question 2
    """
    @Project : Pytorch_learn
    @File    : transform_1.py
    @IDE     : PyCharm
    @Author  : 咋
    @Date    : 2023/6/29 18:16
    """
    from torchvision import transforms
    from PIL import Image

    # Open the image with PIL
    image_path = "data\\train\\ants_image\\5650366_e22b7e1065.jpg"
    image = Image.open(image_path)

    # Create the ToTensor tool and apply it to the image
    transform_tool = transforms.ToTensor()
    image_tensor = transform_tool(image)
    print(image_tensor)
transforms is essentially a toolbox: we can use the tools it provides to build our own tools.
1. ToTensor
    import torch
    from torchvision import transforms
    from PIL import Image

    image_path = "test.jpg"
    image = Image.open(image_path)

    tran_tensor = transforms.ToTensor()
    image_tensor = tran_tensor(image)
    print(type(image))         # PIL image class
    print(type(image_tensor))  # <class 'torch.Tensor'>
ToTensor converts a NumPy array read by OpenCV, or an Image object read by PIL, into a tensor.
2. Normalize
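    # Per-channel mean and standard deviation (one value per RGB channel)
    tran_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    image_norm = tran_norm(image_tensor)

    print(image)       # the original PIL image object
    print(image_norm)  # the normalized tensor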
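Normalization gives the data a similar scale and distribution so that the model trains better. It reduces the differences between samples and improves the convergence speed and stability of the model.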
3. Compose
    # Chain the two tools: first ToTensor, then Normalize
    tran_com = transforms.Compose([tran_tensor, tran_norm])

    image_compose = tran_com(image)
    print(image)          # the original PIL image object
    print(image_compose)  # the tensor after both steps
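Compose chains different transform modules together. Pay attention to the input and output of each module: the output type of one step must be an accepted input type of the next. You can Ctrl+click a function to jump to its source and check its input and output types.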
4. RandomCrop
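    from torch.utils.tensorboard import SummaryWriter

    # Crop a random 512x512 patch; the image must be at least 512x512
    # (or pass pad_if_needed=True), otherwise the crop raises an error.
    tran_crop = transforms.RandomCrop(512)
    write = SummaryWriter("log_1")

    # Log 10 different random crops under the same tag, one per step
    for i in range(10):
        image_crop = tran_crop(image_tensor)
        write.add_image("change", image_crop, i)

    write.close()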
RandomCrop crops the image at a random position; the results can be viewed with TensorBoard.

5. Resize
    # Resize to exactly 512x512 (height, width), ignoring the aspect ratio
    tran_resize = transforms.Resize((512, 512))
    image_resize = tran_resize(image_tensor)
    print(image_tensor.shape)  # shape before resizing
    print(image_resize.shape)  # shape after resizing