[tensorflow] Retraining an Image Classifier for New Categories

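Contents

I. Retraining a model for new categories

1. Load the images and split the dataset into training, testing, and validation sets (80%, 10%, and 10% by default)

2. Load a module from TensorFlow Hub, fetch its information, and create the graph

3. Compute and cache the bottlenecks (feature vectors) for all images

4. Train the new-category model

5. Save the new-category prediction model

II. Model prediction

1. Load the prediction model

2. Load the images to predict (decode and crop them)

3. Predict the categories of multiple images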

Original source: https://www.tensorflow.org/hub/tutorials/image_retraining

I. Retraining a model for new categories

Preliminaries: third-party packages

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

tf.logging.set_verbosity(tf.logging.INFO)

Preliminaries: file paths

# In the upstream retrain.py, FLAGS is populated by argparse; a plain namespace
# is assumed here so that the values can be assigned directly.
FLAGS = argparse.Namespace()

# Raw strings keep the backslashes literal. In an ordinary string, sequences such
# as \t, \f, and \r in these Windows paths would be read as escape characters.
FLAGS.image_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower_photos'
'''Path to folders of labeled images.'''

FLAGS.output_graph = r'E:\DataMining\tensorflow\google-ImageClassification\flower\graph\output_graph.pb'
'''Where to save the trained graph.'''

# A raw string cannot end in a single backslash, so the trailing separator is appended.
FLAGS.intermediate_output_graphs_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\intermediate_graph' + '\\'
'''Where to save the intermediate graphs.'''

FLAGS.intermediate_store_frequency = 0
"""How many steps to store intermediate graph. If "0" then will not store."""

FLAGS.output_labels = r'E:\DataMining\tensorflow\google-ImageClassification\flower\labels\output_labels.txt'
'''Where to save the trained graph's labels.'''

FLAGS.summaries_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\retrain_logs'
'''Where to save summary logs for TensorBoard.'''

FLAGS.how_many_training_steps = 4000
'''How many training steps to run before ending.'''

FLAGS.learning_rate = 0.01
'''How large a learning rate to use when training.'''

FLAGS.testing_percentage = 10
'''What percentage of images to use as a test set.'''

FLAGS.validation_percentage = 10
'''What percentage of images to use as a validation set.'''

FLAGS.eval_step_interval = 10
'''How often to evaluate the training results.'''

FLAGS.train_batch_size = 100
'''How many images to train on at a time.'''

FLAGS.test_batch_size = -1
"""
How many images to test on. This test set is only used once, to evaluate
the final accuracy of the model after training completes.
A value of -1 causes the entire test set to be used, which leads to more
stable results across runs.
"""

FLAGS.validation_batch_size = 100
"""
How many images to use in an evaluation batch. This validation set is
used much more often than the test set, and is an early indicator of how
accurate the model is during training.
A value of -1 causes the entire validation set to be used, which leads to
more stable results across training iterations, but may be slower on large
training sets.
"""

FLAGS.print_misclassified_test_images = False
"""Whether to print out a list of all misclassified test images."""

# Raw string, so that \f and \b in the path are not read as escape characters.
FLAGS.bottleneck_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\bottleneck'
'Path to cache bottleneck layer values as files.'

FLAGS.final_tensor_name = 'final_result'
"""The name of the output classification layer in the retrained graph."""

FLAGS.flip_left_right = False
"""Whether to randomly flip half of the training images horizontally."""

FLAGS.random_crop = 0
"""
A percentage determining how much of a margin to randomly crop off the
training images.
"""

FLAGS.random_scale = 0
"""
A percentage determining how much to randomly scale up the size of the
training images by.
"""

FLAGS.random_brightness = 0
"""
A percentage determining how much to randomly multiply the training image
input pixels up or down by.
"""

FLAGS.tfhub_module = 'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1'
"""Which TensorFlow Hub module to use. For more options, search https://tfhub.dev for image feature vector modules."""

FLAGS.saved_model_dir = r'E:\DataMining\tensorflow\google-ImageClassification\flower\exportedGraph'
"""Where to save the exported graph."""

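Preliminaries: global variables

# Upper bound used by the hash-based split in create_image_lists
# (value as defined in the upstream retrain.py).
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M

# The location where variable checkpoints will be stored.
# A raw string cannot end in a single backslash, so the trailing separator is appended.
CHECKPOINT_NAME = r'E:\DataMining\tensorflow\google-ImageClassification\flower\retrain_checkpoint' + '\\'

# A module is understood as instrumented for quantization with TF-Lite
# if it contains any of these ops.
FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxVarsPerChannel')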

1. Load the images and split the dataset into training, testing, and validation sets (80%, 10%, and 10% by default)

def create_image_lists(image_dir, testing_percentage, validation_percentage):
    """Builds a list of training images from the file system.

    Analyzes the sub folders in the image directory, splits them into stable
    training, testing, and validation sets, and returns a data structure
    describing the lists of images for each label and their paths.

    Args:
      image_dir: String path to a folder containing subfolders of images.
      testing_percentage: Integer percentage of the images to reserve for tests.
      validation_percentage: Integer percentage of images reserved for validation.

    Returns:
      An OrderedDict containing an entry for each label subfolder, with images
      split into training, testing, and validation sets within each label.
      The order of items defines the class indices.
    """
    if not tf.gfile.Exists(image_dir):
        tf.logging.error("Image directory '" + image_dir + "' not found.")
        return None
    result = collections.OrderedDict()
    sub_dirs = sorted(x[0] for x in tf.gfile.Walk(image_dir))
    # The root directory comes first, so skip it.
    is_root_dir = True
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        extensions = sorted(set(os.path.normcase(ext)
                                for ext in ['JPEG', 'JPG', 'jpeg', 'jpg', 'png']))
        file_list = []
        dir_name = os.path.basename(sub_dir)
        if dir_name == image_dir:
            continue
        #tf.logging.info("Looking for images in '" + dir_name + "'")
        for extension in extensions:
            file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
            file_list.extend(tf.gfile.Glob(file_glob))
        if not file_list:
            tf.logging.warning('No files found')
            continue
        if len(file_list) < 20:
            tf.logging.warning(
                'WARNING: Folder has less than 20 images, which may cause issues.')
        elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
            tf.logging.warning(
                'WARNING: Folder {} has more than {} images. Some images will '
                'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
        label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            hash_name = re.sub(r'_nohash_.*$', '', file_name)
            hash_name_hashed = hashlib.sha1(tf.compat.as_bytes(hash_name)).hexdigest()
            percentage_hash = ((int(hash_name_hashed, 16) %
                                (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                               (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            elif percentage_hash < (testing_percentage + validation_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result
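The split in create_image_lists is decided by a hash of each file name, so a given file always lands in the same set from one run to the next. The sketch below isolates that rule in plain Python so it can be tried without TensorFlow. The `which_set` helper and the generated file names are hypothetical, `str.encode` stands in for `tf.compat.as_bytes`, and the value of `MAX_NUM_IMAGES_PER_CLASS` is assumed to match the upstream retrain.py.

```python
import hashlib
import re

# Assumed to match the constant in the upstream retrain.py.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1


def which_set(file_name, testing_percentage=10, validation_percentage=10):
    """Return 'training', 'testing', or 'validation' for a file name."""
    # Anything after '_nohash_' is ignored, so variants of one image stay together.
    hash_name = re.sub(r'_nohash_.*$', '', file_name)
    hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
    # Map the hash to a pseudo-random but stable value between 0 and 100.
    percentage_hash = ((int(hashed, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_IMAGES_PER_CLASS))
    if percentage_hash < validation_percentage:
        return 'validation'
    elif percentage_hash < testing_percentage + validation_percentage:
        return 'testing'
    return 'training'


# Hypothetical file names, used only to look at the resulting proportions.
names = ['img_%04d.jpg' % i for i in range(2000)]
counts = {'training': 0, 'testing': 0, 'validation': 0}
for name in names:
    counts[which_set(name)] += 1
print(counts)
```

The proportions only approximate 80/10/10, since each file is assigned independently. Adding new images later does not move existing files between sets.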

def main(_):
    # Build the lists of images
    image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage)
    class_count = len(image_lists.keys())

The image_lists data has the following format:

([('daisy',
   {'dir': 'daisy',
    'testing': ['100080576_f52e8ee070_n.jpg', '10172379554_b296050f82_n.jpg', ...],
    'training': ['10140303196_b88d3d6cec.jpg', ...],
    'validation': ['102841525_bd6628ae3c.jpg',]}),
  ('dandelion',
   {'dir': 'dandelion',
    'testing': ['10294487385_92a0676c7d_m.jpg', ...],
    'training': ['10140303196_b88d3d6cec.jpg', ...],
    'validation': ['102841525_bd6628ae3c.jpg',]}),
  ...])
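A structure of this shape can be inspected with ordinary dictionary operations. The following sketch counts the images per label and set, and checks that at least two classes were found, which a classifier needs. The helper names and the miniature data are hypothetical, and raising ValueError is a choice made for this sketch rather than something taken from the listing.

```python
import collections


def summarize_image_lists(image_lists):
    """Return, per label, the number of images in each of the three sets."""
    summary = collections.OrderedDict()
    for label, entry in image_lists.items():
        summary[label] = {
            category: len(entry[category])
            for category in ('training', 'testing', 'validation')
        }
    return summary


def check_class_count(image_lists):
    """Return the number of classes, rejecting inputs that cannot be classified."""
    class_count = len(image_lists)
    if class_count == 0:
        raise ValueError('No valid folders of images found.')
    if class_count == 1:
        raise ValueError('Only one class found; at least two are needed.')
    return class_count


# Hypothetical miniature version of the structure shown above.
image_lists = collections.OrderedDict([
    ('daisy', {'dir': 'daisy',
               'training': ['a.jpg', 'b.jpg'],
               'testing': ['c.jpg'],
               'validation': []}),
    ('dandelion', {'dir': 'dandelion',
                   'training': ['d.jpg'],
                   'testing': [],
                   'validation': ['e.jpg']}),
])

print(check_class_count(image_lists))
print(summarize_image_lists(image_lists))
```

Because the dictionary is ordered, the position of each label is also its class index.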

2. Load a module from TensorFlow Hub, fetch its information, and create the graph

# Load a TensorFlow Hub module and fetch its information
def create_module_graph(module_spec):
    """Creates a graph and loads Hub Module into it.

    Args:
      module_spec: the hub.ModuleSpec for the image module being used.

    Returns:
      graph: the tf.Graph that was created.
      bottleneck_tensor: the bottleneck values output by the module.
      resized_input_tensor: the input images, resized as expected by the module.
      wants_quantization: a boolean, whether the module has been instrumented
        with fake quantization ops.
    """
    height, width = hub.get_expected_image_size(module_spec)
    with tf.Graph().as_default() as graph:
        resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3])
        m = hub.Module(module_spec)
        bottleneck_tensor = m(resized_input_tensor)
        wants_quantization = any(node.op in FAKE_QUANT_OPS
                                 for node in graph.as_graph_def().node)
    return graph, bottleneck_tensor, resized_input_tensor, wants_quantization
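The wants_quantization flag comes from a scan over the nodes of the graph for any of the FAKE_QUANT_OPS op types. The sketch below reproduces only that check in plain Python. The `Node` tuple is a hypothetical stand-in for a GraphDef node and models nothing but the `name` and `op` fields, and the node lists are invented for illustration.

```python
import collections

FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxVarsPerChannel')

# Hypothetical stand-in for a GraphDef node; only `name` and `op` are modeled.
Node = collections.namedtuple('Node', ['name', 'op'])


def wants_quantization(nodes):
    """Return True if any node uses one of the fake-quantization ops."""
    return any(node.op in FAKE_QUANT_OPS for node in nodes)


# Invented node lists: one plain graph, one with a fake-quantization op added.
plain_nodes = [Node('input', 'Placeholder'), Node('conv', 'Conv2D')]
quant_nodes = plain_nodes + [Node('fq', 'FakeQuantWithMinMaxVars')]

print(wants_quantization(plain_nodes))
print(wants_quantization(quant_nodes))
```

Since `any` stops at the first match, the scan ends as soon as one fake-quantization op is found.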

# Decode and resize the images

def add_jpeg_decoding(module_spec):

"""Adds operations that perform JPEG decoding and resiz
