Deeplearning
DataSet 的相关内容
Dataset: PyTorch的数据集类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28from torch.utils.data import Dataset # Dataset: PyTorch的数据集类
from PIL import Image # PIL: Python的图像处理库,Pillow是PIL的一个分支
import os # os: Python的标准库,提供了丰富的方法来处理文件和目录
class MyData(Dataset): # 继承Dataset类
def __init__(self, root_dir, lable_dir): # 初始化函数
self.root_dir = root_dir # 数据集的根目录
self.lable_dir = lable_dir # 数据集的标签
self.path = os.path.join(self.root_dir, self.lable_dir) # 数据集的路径
self.img_path = os.listdir(self.path) # 数据集的图片路径 返回指定的文件夹包含的文件或文件夹的名字的列表
def __len__(self): # 返回数据集的长度
return len(self.img_path)
def __getitem__(self, idx): # 返回数据集的一个样本
img_name = self.img_path[idx] # 获取图片的名称
img_item_path = os.path.join(self.path, img_name) # 获取图片的路径
img = Image.open(img_item_path) # 读取图片
lable = self.lable_dir # 获取图片的标签
return img, lable # 返回图片和标签
root_dir = 'dataset/train' # 数据集的根目录
ants_root_dir = 'ants_image' # 蚂蚁的数据集
bees_root_dir = 'bees_image' # 蜜蜂的数据集
ants_dataset = MyData(root_dir, ants_root_dir) # 蚂蚁的数据集
bees_dataset = MyData(root_dir, bees_root_dir) # 蜜蜂的数据集据训练集生成标签集lable
1
2
3
4
5
6
7
8
9
10
11import os
root_dir = "dataset/train"
target_dir = "bees_image"
img_path = os.listdir(os.path.join(root_dir, target_dir))
lable = target_dir.split('_')[0]
out_dir = "bees_lable"
for i in img_path:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(lable)
TensorBoard 的相关操作
- 如何打开Tensorboard 的logs事件文件,port是指定打开的端口号,默认是6006 documentation
1
tensorboard --logdir=logs --port=6006
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19from torch.utils.tensorboard import SummaryWriter # Summary Writer是一个用于写入TensorBoard的类
import numpy as np
from PIL import Image
writer = SummaryWriter('logs') # 创建一个SummaryWriter对象,指定存储路径写入logs文件夹事件文件
## PIL图像类型转为numpy类型
img_PIL = Image.open("dataset/train/ants_image/0013035.jpg") # 打开一张图片
print(type(img_PIL))
img_array = np.array(img_PIL) # 将PIL图像转为numpy数组
print(type(img_array))
writer.add_image("test", img_array, 1, dataformats='HWC') # 添加图像到日志文件中 dataformats='HWC'表示数据格式为高度、宽度、通道数
for i in range(100):
writer.add_scalar("y=x", i, i) # 添加标量数据 第一个参数是标签,第二个参数是数据,第三个参数是迭代次数 参数标签手动输入
writer.close() # 关闭SummaryWriter对象
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 e哥の自我修养!