一、项目概述
大家好!今天我将分享一个我近期完成的深度学习项目——一个功能强大的、带图形化界面(GUI)的水果识别系统。该系统不仅能识别静态图片中的水果,还集成了模型训练、评估、数据增强等功能于一体,为深度学习的入门和实践提供了一个绝佳的案例。
本项目使用 Python 作为主要开发语言,后端算法基于 TensorFlow/Keras 深度学习框架,前端界面则采用 PyQt5 构建,实现了算法与应用的分离,界面美观,交互友好。
核心技术栈:
- GUI框架: PyQt5
- 深度学习框架: TensorFlow 2.x / Keras
- 计算机视觉库: OpenCV-Python
- 数据可视化: Matplotlib
- 核心模型:
- 自定义的轻量级 CNN
- 基于迁移学习的 MobileNetV2
- 基于迁移学习的 VGG16
二、功能展示
系统主界面通过一个选项卡(QTabWidget
)清晰地划分了五大核心功能区。
1. 静态图片识别
用户可以选择本地的水果图片,然后从下拉列表中选择一个已训练好的模型(CNN, MobileNetV2, VGG16)进行识别。识别结果会立刻显示在界面右侧。
2. 实时视频识别(补充功能)
本系统支持通过本地视频文件或直接调用摄像头进行实时识别。在视频流的每一帧上,系统都会进行预测,并将结果实时绘制在画面上,非常直观。
3. 模型训练
这是系统的核心功能之一。用户可以直接在界面上点击按钮,启动对CNN、MobileNetV2或VGG16模型的训练。训练过程中的所有日志(Epoch、loss、accuracy等)都会实时显示在文本框中。训练结束后,准确率和损失曲线图会自动绘制并显示在右侧,同时新模型会被自动加载,可立即用于识别。
4. 模型评估
为了量化模型的性能,评估功能可以计算模型在验证集上的准确率,并生成一个详细的混淆矩阵(Confusion Matrix)热力图。这有助于我们分析模型对哪些类别的识别效果好,哪些容易混淆。
5. 数据增强
提供了一个一键数据增强的工具。它会遍历指定文件夹中的原始图片,通过旋转、平移、缩放、翻转等操作批量生成新的训练样本,有效扩充数据集,防止模型过拟合。
三、系统架构与代码解析
项目的代码结构清晰,每个文件各司其职。
main.ui.py
: 主程序入口和UI界面。负责创建所有窗口控件,处理用户交互事件,并使用QProcess
和QThread
调用后端的训练和识别脚本,避免了界面卡死。CNNTrain.py
: 自定义CNN模型的训练脚本。包含数据加载、模型构建、训练和保存的全过程。MobileNetTrain.py
: MobileNetV2模型的迁移学习训练脚本。VGG16Train.py
: VGG16模型的迁移学习训练脚本。testModel.py
: 模型评估脚本。负责加载模型,在验证集上进行测试,并生成混淆矩阵。geneImage.py
: 数据增强脚本。用于离线扩充数据集。
1. 核心亮点:迁移学习的应用 (MobileNetTrain.py
)
为了在有限的数据集上达到高精度,我们主要采用了迁移学习。以 MobileNetTrain.py
为例,我们加载了在ImageNet上预训练的MobileNetV2模型,并冻结其大部分权重,只训练我们自己添加的分类层。
def model_load(IMG_SHAPE=(224, 224, 3), class_num=15):# 加载预训练的MobileNetV2模型,不包含顶部分类层base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False, # 关键:不加载全连接层weights='imagenet')# 冻结预训练模型的权重,在训练中不更新它们base_model.trainable = Falsemodel = tf.keras.models.Sequential([# 使用预训练的MobileNetV2作为基座base_model,# 对主干模型的输出进行全局平均池化tf.keras.layers.GlobalAveragePooling2D(),# 添加Dropout层,防止分类器过拟合tf.keras.layers.Dropout(0.5),# 添加我们自己的全连接分类层tf.keras.layers.Dense(class_num, activation='softmax')])# 编译模型model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])model.summary()return model
关键点:
include_top=False
:这是使用迁移学习的核心,我们只借用模型的特征提取部分。base_model.trainable = False
:冻结权重可以防止在小数据集上破坏预训练学到的通用特征。- 自定义分类头:在
base_model
之后添加了GlobalAveragePooling2D
、Dropout
和Dense
层,这是我们需要从头开始训练的部分。
2. 界面与逻辑分离 (main.ui.py
)
为了保证用户体验,耗时的任务(如模型训练和实时视频处理)不能阻塞UI主线程。
-
模型训练:通过
QProcess
启动一个外部Python进程来执行训练脚本。这样,训练过程与主界面完全分离,并且可以通过重定向标准输出来捕获日志。def run_script(self, script_name, args=None):# ... 省略部分代码 ...self.process = QProcess(self)# 连接信号与槽,用于读取输出self.process.readyReadStandardOutput.connect(lambda: self.handle_stdout(output_widget))self.process.finished.connect(...)# 启动外部脚本command = f'python -u {script_name}'self.process.start(command)
-
实时识别:通过
QThread
将视频的读取和模型预测放到一个工作线程中。工作线程完成一帧的预测后,通过pyqtSignal
发射一个信号,将处理好的图像(QImage
)传回主线程进行显示。class VideoWorker(QThread):change_pixmap_signal = pyqtSignal(QImage)def run(self):# ... 视频读取和模型预测 ...# 循环中if ret:# ...# 发射信号,将处理后的图像传给UI线程self.change_pixmap_signal.emit(qt_image)# 在主窗口中 self.video_thread.change_pixmap_signal.connect(self.update_frame)
3. 精细化的模型预处理
不同的预训练模型通常需要不同的输入预处理方式。例如,CNN模型通常需要将像素值归一化到[0, 1]
,而MobileNetV2和VGG16则有自己专用的preprocess_input
函数。我们的代码严格区分了这一点,确保模型在预测时接收到正确格式的数据。
# 在 main.ui.py 的 predict_image 方法中
if model_name == 'MobileNetV2':processed_array = mobilenet_preprocess_input(img_array)
elif model_name == 'VGG16':processed_array = vgg16_preprocess_input(img_array)
else: # Default for CNNprocessed_array = img_array / 255.0# 模型预测
predictions = model.predict(processed_array)
四、如何运行
- 环境配置:
pip install tensorflow opencv-python matplotlib pyqt5
- 数据集准备:
在项目根目录的上级目录创建一个fruit
文件夹,内部结构如下:/project_folder/your_scripts_foldermain.ui.py... /fruit/train/Apple1.jpg2.jpg.../Banana.../val/Apple.../Banana...
- 训练模型:
直接运行main.ui.py
,在“模型训练”选项卡中点击相应的按钮进行训练。训练好的模型会保存在models
文件夹下。 - 开始使用:
模型训练完毕后,即可在其他选项卡中进行识别和评估。
五、总结与展望
本项目完整地实现了一个从数据处理、模型训练到部署应用的深度学习全流程。通过PyQt5将复杂的功能封装在友好的图形界面下,大大降低了使用门槛。
未来可扩展的方向:
- 模型优化: 尝试更先进的模型(如EfficientNet)或对当前模型进行微调(Fine-tuning)以提高精度。
- 功能扩展: 增加对水果新鲜度、卡路里等信息的识别与展示。
- 部署: 将模型部署到Web端或移动端,提供更广泛的服务。
希望这个项目能对你有所启发,感谢阅读!如果你觉得不错,欢迎点赞、收藏、关注!