病虫害识别以及分类的实现

模型生成

数据集

数据集地址:IP102-Dataset (kaggle.com)

数据集说明:

IP102 是一个用于害虫识别的大规模基准数据集。以下是它的一些主要特点:

  • 图像数量:包含超过 75,000 张图像
  • 类别:涵盖 102 个不同的昆虫害虫类别
  • 数据分布:呈现自然的长尾分布,模拟了现实世界中的不平衡样本情况
  • 目标检测:为约 19,000 张图像标注了边界框,用于目标检测任务

这个数据集的独特之处在于它的层次分类系统和数据分布不平衡的特点,使其在害虫识别和农业应用中具有重要意义。

训练模型

具体训练过程:Pytorch_VIT Insect Classifier (kaggle.com)

可以使用kaggle里的notebook,也可以使用colab在云端运行

点击Edit My Copy可以直接进入notebook。但代码需要下载timm库,必须在notebook中开启网络功能,而开启网络功能需要先完成手机验证。

我尝试了很久手机验证,始终无法通过,最后改用Colab。

Colab

先在notebook的导航栏中找到file选项,里面有一个open in colab选项,可以直接将项目转到colab中。

在跑代码之前,要注意硬件加速器选择GPU

具体操作方法:

  1. 导航栏中的修改选项
  2. 笔记本设置
  3. 选择GPU

(如果没有选择GPU,代码运行时会出错。我之前就是跑到一半才发现没有开启GPU。)

然后依次执行代码,全部执行完后,就会获得vit_best.pth这个模型文件了

详细的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import numpy as np
import pandas as pd
import os
import random
from tqdm import tqdm
from textwrap import wrap

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import timm

# Read the class list. Each non-empty line is "<label> <name ...>".
label = []
name = []
# The context manager closes the file even if a line fails to parse.
with open('../input/ip02-dataset/classes.txt') as f:
    for line in f:
        fields = line.split()
        if not fields:
            # Skip blank lines instead of failing on fields[0].
            continue
        label.append(int(fields[0]))
        name.append(' '.join(fields[1:]))
classes = pd.DataFrame([label, name]).T
classes.columns = ['label', 'name']
# Bare expression: shows the table when run as a notebook cell.
classes

# Load the train / test / validation split lists.
# Each file is read as space-separated text without a header row.
_SPLIT_COLUMNS = ['image_path', 'label']


def _read_split(path):
    """Read one split file and name its two columns image_path / label."""
    frame = pd.read_csv(path, sep=' ', header=None, engine='python')
    frame.columns = _SPLIT_COLUMNS
    return frame


train_df = _read_split('../input/ip02-dataset/train.txt')
test_df = _read_split('../input/ip02-dataset/test.txt')
val_df = _read_split('../input/ip02-dataset/val.txt')

# Bare expression: previews the first rows when run as a notebook cell.
train_df.head()

# Dataset directories and hyper-parameters.
TRAIN_DIR = '../input/ip02-dataset/classification/train'
TEST_DIR = '../input/ip02-dataset/classification/test'
VAL_DIR = '../input/ip02-dataset/classification/val'
LR = 2e-5  # learning rate
BATCH_SIZE = 8  # samples per batch
EPOCH = 2  # number of training epochs

# Prefer the GPU, but fall back to the CPU so the code still runs (slowly)
# when no CUDA device is available instead of failing on the first .to().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print('CUDA is not available; running on the CPU (training will be slow).')

# Preview one randomly chosen training image per class.
fig, axs = plt.subplots(10, 11, figsize=(30, 30))
images = []
for i in classes.label:
    # Class label i is looked up as train label / folder name i - 1
    # (presumably classes.txt is 1-based while train.txt is 0-based).
    random_img = random.choice(train_df[train_df.label == i - 1].image_path.values)
    img = plt.imread(os.path.join(TRAIN_DIR, str(i - 1), random_img))
    images.append(img)

# zip() stops at the shortest input, so any axes beyond the number of
# classes stay empty.
for image, title, ax in zip(images, classes.name, axs.ravel()):
    ax.imshow(image)
    # Wrap long class names so neighbouring titles do not overlap.
    ax.set_title("\n".join(wrap(title, 20)))
for ax in axs.ravel():
    ax.set_axis_off()
plt.show()

# Model definition
class InsectModel(nn.Module):
    """Classifier wrapping timm's 'vit_base_patch16_224' model.

    Args:
        num_classes: Number of output classes passed to timm.
        pretrained: Forwarded to ``timm.create_model``. Defaults to True
            (the original behaviour); pass False when the weights will be
            overwritten from a checkpoint anyway.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=pretrained,
            num_classes=num_classes,
        )

    def forward(self, image):
        """Return the wrapped model's raw output for ``image``."""
        return self.model(image)

# Data augmentation
def train_transform():
    """Build the training pipeline.

    Applies a horizontal flip, a 90-degree rotation and a
    brightness/contrast change (each with the library's default
    settings), then resizes to 224x224 and converts to a tensor.
    """
    return A.Compose([
        A.HorizontalFlip(),
        A.RandomRotate90(),
        A.RandomBrightnessContrast(),
        A.Resize(224, 224),
        ToTensorV2(),
    ])

def valid_transform():
    """Build the validation pipeline: resize to 224x224 and convert to a tensor."""
    return A.Compose([
        A.Resize(224, 224),
        ToTensorV2(),
    ])

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``; an empty
    batch yields ``()``.
    """
    return tuple(zip(*batch))

# Dataset definition
class InsectDataset(Dataset):
    """Dataset yielding ``(image, label)`` tensor pairs.

    Args:
        image: Indexable 2-D array whose rows are ``(file name, label)``;
            only ``shape[0]`` and integer row indexing are used.
        image_dir: Root directory; an image is read from
            ``<image_dir>/<label>/<file name>``.
        transforms: Optional callable invoked as
            ``transforms(image=...)['image']``.
    """

    def __init__(self, image, image_dir, transforms=None):
        self.image_info = image
        self.transforms = transforms
        self.imgdir = image_dir

    def __len__(self):
        return self.image_info.shape[0]

    def __getitem__(self, index):
        """Load, scale and transform the sample at ``index``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        image_info = self.image_info[index]
        path = os.path.join(self.imgdir, str(image_info[1]), image_info[0])
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None;
            # fail here with the path instead of a cryptic cvtColor error.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.

        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        image = torch.as_tensor(image, dtype=torch.float32)
        # int() accepts both Python and NumPy integer labels.
        label = torch.as_tensor(int(image_info[1]), dtype=torch.long)
        return image, label

# Build the datasets and their data loaders.
# Both loaders share the same batch size, shuffling and worker count.
_LOADER_OPTIONS = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

train_dataset = InsectDataset(train_df.values, TRAIN_DIR, transforms=train_transform())
train_data_loader = DataLoader(train_dataset, **_LOADER_OPTIONS)

val_dataset = InsectDataset(val_df.values, VAL_DIR, transforms=valid_transform())
val_data_loader = DataLoader(val_dataset, **_LOADER_OPTIONS)

# Running-average meter
class AverageMeter(object):
    """Accumulate a sample-weighted mean loss and an accuracy.

    Attributes set by ``reset``/``update``:
        loss: Loss value passed to the most recent ``update``.
        correct: Total of the ``correct`` values seen so far.
        sum: Total of ``loss * n`` seen so far.
        count: Total of ``n`` seen so far.
        avg: ``sum / count``.
        acc: ``correct / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.loss = 0
        self.correct = 0
        self.avg = 0
        # Initialised here so ``acc`` is readable before the first update.
        self.acc = 0
        self.sum = 0
        self.count = 0

    def update(self, loss, correct, n=1):
        """Record one batch.

        Args:
            loss: Mean loss of the batch.
            correct: Number of correct predictions in the batch.
            n: Number of samples in the batch.
        """
        self.loss = loss
        self.correct += correct
        self.sum += loss * n
        self.count += n

        # Guard against division by zero when nothing has been counted yet.
        if self.count:
            self.avg = self.sum / self.count
            self.acc = self.correct / self.count

class Accuracy(object):
    """Counter of correct predictions and samples seen.

    The original only referenced ``self.reset`` without calling it and
    defined no ``reset`` method, so construction raised AttributeError.
    The counters mirror the ``correct``/``count`` fields of AverageMeter.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the counters."""
        self.correct = 0
        self.count = 0

# Training loop for one epoch
def train_fn(data_loader, model, criterion, device, optimizer, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.
        model: Module being trained.
        criterion: Loss module called as ``criterion(output, labels)``.
        device: Device the batches are moved to.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index (shown as ``epoch + 1`` in the bar).

    Returns:
        AverageMeter holding the running loss (``avg``) and accuracy (``acc``).
    """
    model.train()
    criterion.train()

    summary = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))
    for images, labels in tk0:
        images = images.to(device, non_blocking=True).float()
        labels = labels.to(device, non_blocking=True).long()

        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # softmax is monotonic, so argmax on the raw output is equivalent.
        preds = output.argmax(1)
        correct = (preds == labels).sum().item()

        # Weight by the actual batch size: the last batch may be smaller
        # than BATCH_SIZE, which would otherwise skew loss and accuracy.
        summary.update(loss.item(), correct, labels.size(0))
        tk0.set_postfix(loss=summary.avg, acc=summary.acc, epoch=epoch + 1)
    return summary

# Validation loop
def eval_fn(data_loader, model, criterion, device, epoch):
    """Evaluate ``model`` on ``data_loader`` without computing gradients.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.
        model: Module being evaluated.
        criterion: Loss module called as ``criterion(output, labels)``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (shown as ``epoch + 1`` in the bar).

    Returns:
        AverageMeter holding the running loss (``avg``) and accuracy (``acc``).
    """
    model.eval()
    criterion.eval()

    summary = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))
    with torch.no_grad():
        for images, labels in tk0:
            images = images.to(device, non_blocking=True).float()
            labels = labels.to(device, non_blocking=True).long()

            output = model(images)
            loss = criterion(output, labels)

            # softmax is monotonic, so argmax on the raw output is equivalent.
            preds = output.argmax(1)
            correct = (preds == labels).sum().item()

            # Weight by the actual batch size: the last batch may be smaller
            # than BATCH_SIZE, which would otherwise skew loss and accuracy.
            summary.update(loss.item(), correct, labels.size(0))
            tk0.set_postfix(loss=summary.avg, acc=summary.acc, epoch=epoch + 1)
    return summary

os.environ['WANDB_CONSOLE'] = 'off'


# Run training and validation
def run():
    """Train for EPOCH epochs and keep the checkpoint with the lowest validation loss.

    After every epoch the model is validated; whenever the validation
    loss improves, the state dict is written to ``vit_best.pth``.
    """
    model = InsectModel(num_classes=102)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    # Infinity guarantees the first finite validation loss is saved.
    best_loss = float('inf')
    for epoch in range(EPOCH):
        train_loss = train_fn(train_data_loader, model, criterion, device, optimizer, epoch)
        val_loss = eval_fn(val_data_loader, model, criterion, device, epoch)
        if val_loss.avg < best_loss:
            best_loss = val_loss.avg
            torch.save(model.state_dict(), 'vit_best.pth')
        print(f'Epoch {epoch + 1:03}: | Train Loss: {train_loss.avg:.5f} | Val Loss: {val_loss.avg:.5f}')


run()

# Load the best checkpoint and predict one validation batch
model = InsectModel(num_classes=102)
# This model is not moved to `device`, so map the checkpoint tensors to the
# CPU explicitly instead of relying on the device they were saved from.
model.load_state_dict(torch.load("./vit_best.pth", map_location='cpu'))
# Switch to inference mode; the freshly built module starts in training mode.
model.eval()
images, labels = next(iter(val_data_loader))
# No gradients are needed for prediction.
with torch.no_grad():
    preds = model(images).softmax(1).argmax(1)

# Show the predictions for the batch
def _draw_predictions(axes, batch_images, truths, guesses):
    """Draw each image and title it with its actual / predicted class name.

    The title is green when the prediction matches and red otherwise.
    """
    for picture, axis in zip(batch_images, axes):
        # permute moves dim 0 last, as required by imshow for (H, W, C) data.
        axis.imshow(picture.permute((1, 2, 0)))
    for truth, guess, axis in zip(truths, guesses, axes):
        caption = f'实际: {classes.name[truth.item()]} 预测: {classes.name[guess.item()]}'
        colour = 'g' if truth.item() == guess.item() else 'r'
        axis.set_title("\n".join(wrap(caption, 30)), color=colour)
    for axis in axes:
        axis.set_axis_off()


fig, axs = plt.subplots(2, 4, figsize=(13, 8))
_draw_predictions(axs.ravel(), images, labels, preds)
plt.show()

模型使用

结合我博客中的手写体识别案例:视频捕获和画面显示的部分完全相同。但之前的模型是tflite格式,而现在的模型是pth格式,并且类别更多,因此需要修改部分代码。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import cv2
import torch
import time
import numpy as np
import pandas as pd
import timm
import torch.nn as nn
from textwrap import wrap

# Address of the video stream
url = "http://192.168.2.8/stream"

# Open the camera stream; `cap` is read frame by frame in the main loop
cap = cv2.VideoCapture(url)

# Model definition and checkpoint loading
class InsectModel(nn.Module):
    """Classifier wrapping timm's 'vit_base_patch16_224' model.

    Args:
        num_classes: Number of output classes passed to timm.
        pretrained: Forwarded to ``timm.create_model``. Defaults to True
            (the original behaviour).
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=pretrained,
            num_classes=num_classes,
        )

    def forward(self, image):
        """Return the wrapped model's raw output for ``image``."""
        return self.model(image)


try:
    # pretrained=False: every weight is replaced by the checkpoint below, so
    # there is no need to fetch timm's pretrained weights first.
    model = InsectModel(num_classes=102, pretrained=False)
    # map_location="cpu": the checkpoint may hold CUDA tensors, which cannot
    # be deserialised on a machine without CUDA unless they are remapped.
    model.load_state_dict(torch.load("./vit_best.pth", map_location="cpu"))
    model.eval()
    print("模型加载成功")
except Exception as e:
    print(f"模型加载失败: {e}")
    # Without a model the recognition loop would only fail later with a
    # NameError, so stop here with the real cause.
    raise

# Image preprocessing
def preprocess(image):
    """Turn a BGR frame into a model input tensor.

    The frame is converted to RGB float32, divided by 255, resized to
    224x224, reordered to channels-first and given a leading batch
    dimension of 1.

    Args:
        image: Frame as returned by OpenCV (BGR channel order).

    Returns:
        The tensor, or None when any step fails (the error is printed so
        the capture loop can keep running).
    """
    try:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        image = cv2.resize(image, (224, 224))
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
        print("图像预处理成功")
        return image
    except Exception as e:
        # Deliberately broad: a bad frame must not stop the video loop.
        print(f"图像预处理失败: {e}")
        return None

# Read the class names. Each non-empty line is "<label> <name ...>".
try:
    label = []
    name = []
    # The context manager closes the file even if a line fails to parse.
    with open('classes.txt') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines instead of failing on fields[0].
                continue
            label.append(int(fields[0]))
            name.append(' '.join(fields[1:]))
    classes = pd.DataFrame([label, name]).T
    classes.columns = ['label', 'name']
    print("类别名称读取成功")
except Exception as e:
    print(f"类别名称读取失败: {e}")
    # Without `classes` the recognition loop would only fail later with a
    # NameError, so stop here with the real cause.
    raise

last_recognition_time = time.time()

try:
    while True:
        # Read one frame from the stream
        ret, frame = cap.read()

        if not ret:
            # Re-open the stream instead of exiting when a read fails.
            print("无法读取视频流,尝试重新连接...")
            cap.release()
            cap = cv2.VideoCapture(url)
            continue

        # Show the frame in a window
        cv2.imshow('Camera', frame)

        # Run recognition at most once every 10 seconds
        if time.time() - last_recognition_time > 10:
            image = preprocess(frame)
            if image is not None:
                with torch.no_grad():
                    output = model(image)
                pred = output.softmax(1).argmax(1).item()
                if pred < len(classes):
                    label_name = classes.name[pred]
                else:
                    # Always bind label_name so the print below cannot hit
                    # an unbound (or stale) value.
                    label_name = "无法识别"
                print(f"识别结果: {label_name}")
            # Restart the timer even when preprocessing failed, so a bad
            # frame does not trigger a retry on every following frame.
            last_recognition_time = time.time()

        # Leave the loop when 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the stream and close the window even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()

注意:

  1. vit_best.pth就是刚刚生成的模型文件

  2. classes.txt是数据集内的同名文件

  3. if not ret:
        print("无法读取视频流,尝试重新连接...")
        cap.release()
        cap = cv2.VideoCapture(url)
        continue
    

    这段代码可以防止Stream ends prematurely at ......报错