本文主要描述如何使用 Google 开源的目标检测 API 来训练目标检测器,内容包括:安装 TensorFlow/Object Detection API 和使用 TensorFlow/Object Detection API 训练自己的目标检测器。
Google 开源的目标检测项目 object_detection 位于与 tensorflow 独立的项目 models(独立指的是:在安装 tensorflow 的时候并没有安装 models 部分)内:models/research/object_detection。models 部分的 GitHub 主页为:
https://github.com/tensorflow/models
要使用 models 部分内的目标检测功能 object_detection,需要用户手动安装 object_detection。下面为详细的安装步骤:
1. 安装依赖项 matplotlib,pillow,lxml 等使用 pip/pip3 直接安装:
$ sudo pip/pip3 install matplotlib pillow lxml1
其中如果安装 lxml 不成功,可使用
$ sudo apt-get install python-lxml python3-lxml1
安装。
2. 安装编译工具$ sudo apt install protobuf-compiler $ sudo apt-get install python-tk $ sudo apt-get install python3-tk123 3. 克隆 TensorFlow models 项目
使用 git 克隆 models 部分到本地,在终端输入指令:
$ git clone https://github.com/tensorflow/models.git1
克隆完成后,会在终端当前目录出现 models 的文件夹。要使用 git(分布式版本控制系统),首先得安装 git:$ sudo apt-get install git。
4. 使用 protoc 编译在 models/research 目录下的终端执行:
$ protoc object_detection/protos/*.proto --python_out=.1
将 object_detection/protos/ 文件下的以 .proto 为后缀的文件编译为 .py 文件输出。
5. 配置环境变量在 .bashrc 文件中加入环境变量。首先打开 .bashrc 文件:
$ sudo gedit ~/.bashrc1
然后在文件末尾加入新行:
export PYTHONPATH=$PYTHONPATH:/.../models/research:/.../modes/research/slim1
其中省略号所在的两个目录需要填写为 models/research 文件夹、models/research/slim 文件夹的完整目录。保存之后执行如下指令:
$ source ~/.bashrc1
让改动立即生效。
6. 测试是否安装成功在 models/research 文件下执行:
$ python/python3 object_detection/builders/model_builder_test.py1
如果返回 OK,表示安装成功。
成功安装好 TensorFlow Object Detection API 之后,就可以按照 models/research/object_detection 文件夹下的演示文件 object_detection_tutorial.ipynb 来查看 Google 自带的目标检测的检测效果。其中,Google 自己训练好后的目标检测器都放在:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
可以自己下载这些模型,一一查看检测效果。以下,假设你把某些预训练模型下载好了,放在models/ research/ object_detection 的某个文件夹下,比如自定义文件夹 pretrained_models。
要训练自己的模型,除了使用 Google 自带的预训练模型之外,最关键的是需要准备自己的训练数据。
以下,详细列出训练过程(后续部分文章将详细介绍一些目标检测算法):
1. 准备标注工具和文件格式转化工具图像标注可以使用标注工具 labelImg,直接使用
$ sudo pip install labelImg1
安装(当前好像只支持Python2.7)。另外,在此之前,需要安装它的依赖项 pyqt4:
$ sudo apt-get install pyqt4-dev-tools1
(另一依赖项 lxml 前面已安装)。要使用 labelImg,只需要在终端输入 labelImg 即可。
为了方便后续数据格式转化,还需要准备两个文件格式转化工具:xml_to_csv.py 和 generate_tfrecord.py,它们的代码分别列举如下(它们可以从资料 [1] 中 GitHub 项目源代码链接中下载。其中为了方便一般化使用,我已经修改 generate_tfrecord.py 的部分内容使得可以自定义图像路径和输入 .csv 文件、输出 .record 文件路径,以及 6 中的 xxx_label_map.pbtxt 文件路径):
(1) xml_to_csv.py 文件源码:
import os import glob import pandas as pd import xml.etree.ElementTree as ET def 1234567
相关知识
[tensorflow]图片新类别再训练
TensorFlow入门
TensorFlow学习记录(八)
tensorflow识别花朵
Tensorflow训练鸢尾花数据集
花卉识别(tensorflow)
深度学习入门——基于TensorFlow的鸢尾花分类实现(TensorFlow
基于tensorflow的花卉识别
构建、训练和部署 102 种花卉类型的真实花卉分类器
基于TensorFlow Lite实现的Android花卉识别应用
网址: TensorFlow 训练自己的目标检测器 https://m.huajiangbk.com/newsview504253.html
上一篇: 植物碳排放在线监测系统设备 |
下一篇: 家庭健康监测,创新的家庭健康评估 |