一、Estimator
1、介绍
- 编程堆栈

- Estimator:代表一个完整的模型。- Estimator API提供一些方法来训练模型、判断模型的准确率并生成预测。
- 数据集:构建数据输入管道。Dataset API提供一些方法来加载和操作数据,并将数据馈送到您的模型中。Dataset API与Estimator API合作无间
2、鸢尾花进行分类
- 数据集介绍:4个属性,分为3类:
| 花萼长度 | 花萼宽度 | 花瓣长度 | 花瓣宽度 | 品种(标签) | 
|---|---|---|---|---|
| 5.1 | 3.3 | 1.7 | 0.5 | 0(山鸢尾) | 
| 5.0 | 2.3 | 3.3 | 1.0 | 1(变色鸢尾) | 
| 6.4 | 2.8 | 5.6 | 2.2 | 2(维吉尼亚鸢尾) | 
- 网络模型

3、实现
- Estimator是- TensorFlow对完整模型的高级表示。它会处理初始化、日志记录、保存和恢复等细节部分,并具有很多其他功能,以便您可以专注于模型。
3.1 预创建模型
- 完整代码:点击查看
- 导入包和参数配置
| 
 | 
 | 
- 构建模型- 特征列:feature_column:特征列是一个对象,用于说明模型应该如何使用特征字典中的原始输入数据。在构建Estimator模型时,您会向其传递一个特征列的列表,其中包含您希望模型使用的每个特征。tf.feature_column模块提供很多用于向模型表示数据的选项。- 对于鸢尾花问题,4 个原始特征是数值,因此我们会构建一个特征列的列表,以告知 Estimator模型将这 4 个特征都表示为 32 位浮点值。
 
- 对于鸢尾花问题,4 个原始特征是数值,因此我们会构建一个特征列的列表,以告知 
- 实例化 Estimator: 使用的是预创建模型cls = tf.estimator.DNNClassifier()模型
- 训练模型 cls.train(input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None):- input_fn指定输入的函数,包含- (features, labels)的- tf.data.Dataset类型的数据
- steps参数告知方法在训练多少步后停止训练。
 
- 评估经过训练的模型:eval_res = cls.evaluate(input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)- 输入和训练数据一致
- 返回的有{'accuracy': 1.0, 'loss': 3.936471, 'average_loss': 0.1312157, 'global_step': 100}
 
- 预测: predictions = cls.predict(input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True)- 输入数据为 batch_size的测试数据,不包含label,返回生成器结果
 
- 输入数据为 
 
- 特征列:
| 
 | 
 | 
- 运行函数 - tf.app.run(main=main)会先解析命令行参数,然后执行- main函数123if __name__ == "__main__":tf.logging.set_verbosity(tf.logging.INFO)tf.app.run(main=main)
 
- 保存和加载模型 - 指定模型地址即可:model_dir,在第一次训练时会保存模型 - 如果未在 Estimator的构造函数中指定model_dir,则Estimator会将检查点文件写入由Python的tempfile.mkdtemp函数选择的临时目录中,可以print(classifier.model_dir)查看
 
- 如果未在 
- 检查点频率:- 默认- 每 10分钟(600秒)写入一个检查点。
- 在 train方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点。
- 只在目录中保留 5个最近写入的检查点。
 
- 每 
- 自己配置:123456my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_secs = 20*60, # 每20分钟保存一次keep_checkpoint_max = 10) # 保存10个最近的检查点cls = tf.estimator.DNNClassifier(hidden_units=[10,10], feature_columns=my_feature_columns,n_classes=3,model_dir='model/',config=my_checkpoint_config)
 
- 默认
 
- 指定模型地址即可:
- 加载模型- 不需要改动,一旦存在检查点,TensorFlow就会在您每次调用train()、evaluate()或predict()时重建模型。 
 
- 不需要改动,一旦存在检查点,
 
- 加载模型
3.2 自定义模型
- 完整代码:点击查看
- 预创建的 Estimator是tf.estimator.Estimator基类的子类,而自定义Estimator是tf.estimator.Estimator的实例 
- 创建模型 - 模型函数(即 model_fn)会实现机器学习算法
- params参数会传递给自己实现的模型123456cls = tf.estimator.Estimator(model_fn=my_model,params={'feature_columns': my_feature_columns,'hidden_units': [10, 10],'num_classes': 3})
 
- 模型函数(即 
- 自定义 - my_model函数:- 输入层指定输入的数据和对应的feature columns
- 隐藏层通过tf.layers.dense()创建
- 通过mode来判断是训练、评价还是预测操作,返回必须是tf.estimator.EstimatorSpec对象 
 
- 输入层指定输入的数据和对应的
| 
 | 
 | 
- 在 TensorBoard中查看自定义Estimator的训练结果。(预定义的模型结果展示更丰富一些)- tensorboard --logdir=PATH
- global_step/sec:这是一个性能指标,显示我们在进行模型训练时每秒处理的批次数(梯度更新)。 
- loss:所报告的损失。 
- accuracy:准确率由下列两行记录:- eval_metric_ops={‘my_accuracy’: accuracy})(评估期间)。
- tf.summary.scalar(‘accuracy’, accuracy1)(训练期间)。 
 
 
二、Dataset
- tf.data模块包含一系列类,可让轻松地加载数据、操作数据并通过管道将数据传送到模型中。
1、基本输入
- 从数组中提取接片,上面用到的代码 - feature:特征数据,为- feature-name: array的字典或者- DataFrame
- labels: 标签数组
- from_tensor_slices会按第一个维度进行切片,比如输入为- [6000, 28, 28]维度的数据,切片后返回- 6000个- 28, 28的- Dataset对象
- shuffle方法使用一个固定大小的缓冲区,在条目经过时随机化处理条目。在这种情况下,- buffer_size大于- Dataset中样本的数量,确保数据完全被随机化处理。
- repeat方法会在结束时重启- Dataset。要限制周期数量,请设置- count参数。
- batch方法会收集大量样本并将它们堆叠起来以创建批次。这为批次的形状增加了一个维度。新的维度将添加为第一个维度。1234567def train_input_fn(features, labels, batch_size):"""训练集输入函数"""dataset = tf.data.Dataset.from_tensor_slices((dict(features,), labels)) # 转化为Datasetdataset = dataset.shuffle(buffer_size=1000).repeat().batch(batch_size) # Shuffle, batchreturn dataset
 
2、读取CSV文件
- 代码
- 处理一行数据, - line: tf.string类型1234567CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]def _parse_line(line):'''解析一行数据'''field = tf.decode_csv(line, record_defaults=CSV_TYPES)features = dict(zip(CSV_COLUMN_NAMES, field))labels = features.pop("Species")return features, labels
- 处理 - text文件,得到- dataset- 读取文本类型为:<SkipDataset shapes: (), types: tf.string>
- 然后使用map函数,每个对象处理 123456def csv_input_fn(csv_path, batch_size):'''csv文件输入函数'''dataset = tf.data.TextLineDataset(csv_path).skip(1) # 跳过第一行dataset = dataset.map(_parse_line) # 应用map函数处理dataset中的每一个元素dataset = dataset.shuffle(1000).repeat().batch(batch_size)return dataset 123456def csv_input_fn(csv_path, batch_size):'''csv文件输入函数'''dataset = tf.data.TextLineDataset(csv_path).skip(1) # 跳过第一行dataset = dataset.map(_parse_line) # 应用map函数处理dataset中的每一个元素dataset = dataset.shuffle(1000).repeat().batch(batch_size)return dataset
 
- 读取文本类型为:
Reference
- https://tensorflow.google.cn/get_started/get_started_for_beginners?hl=zh-cn
- https://tensorflow.google.cn/get_started/premade_estimators?hl=zh-cn
- https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py
- 本文链接: http://lawlite.me/2018/05/31/Tensorflow高级API/
- 
      版权声明: 
      本博客所有文章除特别声明外,均采用 CC BY-NC-SA 3.0 许可协议 。转载请注明出处! 。转载请注明出处!
 
		 
                      