目录
一、新类别模型的再训练
1、图片加载,并将数据集划分为训练集、测试集、验证集,比例分别为80%,10%,10%(默认)
2、加载hub某个模型,拉取模型信息,创建图
3、计算所有图片的bottlenecks(特征向量),并缓存
4、新类别模型训练
5、新类别预测模型保存
二、模型预测
1、预测模型加载
2、加载预测图片(图片进行解码和剪裁)
3、多张图片类别预测
原网址:https://www.tensorflow.org/hub/tutorials/image_retraining
预定义-第三方包
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
tf.logging.set_verbosity(tf.logging.INFO)
预定义--文件路径
# Container for the retraining settings below. Nothing earlier in this file
# defines FLAGS, so it is created here before any attribute is assigned.
FLAGS = argparse.Namespace()

# NOTE: Windows paths are written as raw strings so that sequences such as
# "\t", "\f", "\r" and "\b" are kept literally instead of being interpreted
# as escape characters. A raw string cannot end in a single backslash, so a
# trailing separator is appended explicitly where one is needed.

# Path to folders of labeled images.
FLAGS.image_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower_photos'

# Where to save the trained graph.
FLAGS.output_graph = r'E:\DataMining\tensorflow\google-ImageClassification\flower\graph\output_graph.pb'

# Where to save the intermediate graphs (keeps its trailing separator).
FLAGS.intermediate_output_graphs_dir = (
    r'E:\DataMining\tensorflow\google-ImageClassification\flower\intermediate_graph' + '\\')

# How many steps to store intermediate graph. If "0" then will not store.
FLAGS.intermediate_store_frequency = 0

# Where to save the trained graph's labels.
FLAGS.output_labels = r'E:\DataMining\tensorflow\google-ImageClassification\flower\labels\output_labels.txt'

# Where to save summary logs for TensorBoard.
FLAGS.summaries_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\retrain_logs'

# How many training steps to run before ending.
FLAGS.how_many_training_steps = 4000

# How large a learning rate to use when training.
FLAGS.learning_rate = 0.01

# What percentage of images to use as a test set.
FLAGS.testing_percentage = 10

# What percentage of images to use as a validation set.
FLAGS.validation_percentage = 10

# How often to evaluate the training results.
FLAGS.eval_step_interval = 10

# How many images to train on at a time.
FLAGS.train_batch_size = 100

# How many images to test on. This test set is only used once, to evaluate
# the final accuracy of the model after training completes.
# A value of -1 causes the entire test set to be used, which leads to more
# stable results across runs.
FLAGS.test_batch_size = -1

# How many images to use in an evaluation batch. This validation set is
# used much more often than the test set, and is an early indicator of how
# accurate the model is during training.
# A value of -1 causes the entire validation set to be used, which leads to
# more stable results across training iterations, but may be slower on large
# training sets.
FLAGS.validation_batch_size = 100

# Whether to print out a list of all misclassified test images.
FLAGS.print_misclassified_test_images = False

# Path to cache bottleneck layer values as files.
FLAGS.bottleneck_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\bottleneck'

# The name of the output classification layer in the retrained graph.
FLAGS.final_tensor_name = 'final_result'

# Whether to randomly flip half of the training images horizontally.
FLAGS.flip_left_right = False

# A percentage determining how much of a margin to randomly crop off the
# training images.
FLAGS.random_crop = 0

# A percentage determining how much to randomly scale up the size of the
# training images by.
FLAGS.random_scale = 0

# A percentage determining how much to randomly multiply the training image
# input pixels up or down by.
FLAGS.random_brightness = 0

# Which TensorFlow Hub module to use. For more options, search
# https://tfhub.dev for image feature vector modules.
FLAGS.tfhub_module = 'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1'

# Where to save the exported graph.
FLAGS.saved_model_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\exportedGraph'
预定义--全局变量
# The location where variable checkpoints will be stored.
# Written as a raw string so "\t", "\f" and "\r" stay literal; the trailing
# path separator is appended separately because a raw string cannot end in a
# single backslash (and a plain literal ending in \' never terminates).
CHECKPOINT_NAME = (
    r'E:\DataMining\tensorflow\google-ImageClassification\flower\retrain_checkpoint' + '\\')

# A module is understood as instrumented for quantization with TF-Lite
# if it contains any of these ops.
FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxVarsPerChannel')

# Upper bound used by create_image_lists when hashing file names into the
# training / testing / validation split.
# NOTE(review): no definition is visible in this file; the value below is the
# one used by the upstream TensorFlow retrain script (~134M) - confirm.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1
def create_image_lists(image_dir, testing_percentage, validation_percentage):
    """Builds a list of training images from the file system.

    Analyzes the sub folders in the image directory, splits them into stable
    training, testing, and validation sets, and returns a data structure
    describing the lists of images for each label and their paths.

    The split is decided per file by hashing its path (with any
    '_nohash_...' suffix removed), so a given file lands in the same set on
    every run. Relies on the module-level constant MAX_NUM_IMAGES_PER_CLASS.

    Args:
        image_dir: String path to a folder containing subfolders of images.
        testing_percentage: Integer percentage of the images to reserve for
            tests.
        validation_percentage: Integer percentage of images reserved for
            validation.

    Returns:
        An OrderedDict containing an entry for each label subfolder, with
        images split into training, testing, and validation sets within each
        label. The order of items defines the class indices. Returns None
        when image_dir does not exist.
    """
    if not tf.gfile.Exists(image_dir):
        tf.logging.error("Image directory '" + image_dir + "' not found.")
        return None

    result = collections.OrderedDict()
    sub_dirs = sorted(x[0] for x in tf.gfile.Walk(image_dir))

    # The extension list does not depend on the folder, so build it once.
    # normcase collapses case variants on case-insensitive platforms.
    extensions = sorted(set(
        os.path.normcase(ext) for ext in ['JPEG', 'JPG', 'jpeg', 'jpg', 'png']))

    # The root directory comes first, so skip it.
    for sub_dir in sub_dirs[1:]:
        dir_name = os.path.basename(sub_dir)
        if dir_name == image_dir:
            continue

        file_list = []
        for extension in extensions:
            file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
            file_list.extend(tf.gfile.Glob(file_glob))
        if not file_list:
            tf.logging.warning('No files found')
            continue
        if len(file_list) < 20:
            tf.logging.warning(
                'WARNING: Folder has less than 20 images, which may cause issues.')
        elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
            tf.logging.warning(
                'WARNING: Folder {} has more than {} images. Some images will never be selected.'.format(
                    dir_name, MAX_NUM_IMAGES_PER_CLASS))

        # Labels are the lower-cased folder name with every run of other
        # characters replaced by a single space.
        label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Ignore anything after '_nohash_' when deciding the set, so that
            # files sharing the same prefix hash to the same value.
            hash_name = re.sub(r'_nohash_.*$', '', file_name)
            hash_name_hashed = hashlib.sha1(
                tf.compat.as_bytes(hash_name)).hexdigest()
            # Map the hash to a value in [0, 100] that is compared against
            # the requested percentages.
            percentage_hash = (
                (int(hash_name_hashed, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1))
                * (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            elif percentage_hash < (testing_percentage + validation_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)

        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result
def main(_):
    """Entry point of the retraining flow (only its first step is shown here).

    Args:
        _: Unused positional argument.
    """
    # Collect the labeled images and split them per label into
    # training / testing / validation lists.
    image_lists = create_image_lists(
        FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage)
    # Number of labels (one per image sub-folder).
    class_count = len(image_lists.keys())
image_lists 的数据格式为(OrderedDict,键为类别标签,值为该类别的图片划分):
OrderedDict([
    ('daisy',
     {'dir': 'daisy',
      'testing': ['100080576_f52e8ee070_n.jpg', '10172379554_b296050f82_n.jpg', ...],
      'training': ['10140303196_b88d3d6cec.jpg', ...],
      'validation': ['102841525_bd6628ae3c.jpg', ...]}),
    ('dandelion',
     {'dir': 'dandelion',
      'testing': ['10294487385_92a0676c7d_m.jpg', ...],
      'training': ['10140303196_b88d3d6cec.jpg', ...],
      'validation': ['102841525_bd6628ae3c.jpg', ...]}),
    ...
])
#加载tensorflow中的某个模型,并拉取模型信息
def create_module_graph(module_spec):
    """Creates a graph and loads Hub Module into it.

    Args:
        module_spec: the hub.ModuleSpec for the image module being used.

    Returns:
        graph: the tf.Graph that was created.
        bottleneck_tensor: the bottleneck values output by the module.
        resized_input_tensor: the input images, resized as expected by the
            module.
        wants_quantization: a boolean, whether the module has been
            instrumented with fake quantization ops.
    """
    height, width = hub.get_expected_image_size(module_spec)
    # Build everything inside a fresh graph so the module's ops are isolated.
    with tf.Graph().as_default() as graph:
        # Batch of RGB images already resized to the module's expected size.
        resized_input_tensor = tf.placeholder(
            tf.float32, [None, height, width, 3])
        m = hub.Module(module_spec)
        bottleneck_tensor = m(resized_input_tensor)
        # The module counts as quantization-instrumented if any node in the
        # graph uses one of the ops listed in FAKE_QUANT_OPS.
        wants_quantization = any(
            node.op in FAKE_QUANT_OPS for node in graph.as_graph_def().node)
    return graph, bottleneck_tensor, resized_input_tensor, wants_quantization
#对图片进行解码和调整大小
def add_jpeg_decoding(module_spec):
"""Adds operations that perform JPEG decoding and resiz
相关知识
花卉识别(tensorflow)
Tensorflow训练鸢尾花数据集
TensorFlow学习记录(八)
使用TensorFlow给花朵分类
基于tensorflow的花卉识别
TensorFlow入门
基于TensorFlow Lite实现的Android花卉识别应用
深度学习之基于Tensorflow卷积神经网络花卉识别系统
Keras花卉分类全流程(预处理+训练+预测)
基于TensorFlow的CNN卷积网络模型花卉分类(1)
网址: [tensorflow]图片新类别再训练 https://m.huajiangbk.com/newsview343176.html
上一篇: 构建、训练和部署 102 种花卉 |
下一篇: 数据集划分,Oxford Flo |