目录
Estimator
class Estimator(builtins.object)
__init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None)
model_fn
这个模型给定输入和参数,会返回训练、验证或者预测等所需要的操作节点。 所有的输出(检查点、事件文件等)会写入到 model_dir,或者其子文件夹中。如果 model_dir 为空,则默认为临时目录。
config
参数为 tf.estimator.RunConfig 对象,包含了执行环境的信息。如果没有传递 config,则它会被 Estimator 实例化,使用的是默认配置。
params
包含了超参数。Estimator 只传递超参数,不会检查超参数,因此 params 的结构完全取决于开发者。
其它细节
Estimator 的所有方法都不能被子类覆盖(它的构造方法强制决定的)。子类应该使用 model_fn 来配置母类,或者增添方法来实现特殊的功能。 Estimator 不支持 Eager Execution(eager execution能够使用Python 的debug工具、数据结构与控制流。并且无需使用placeholder、session,计算结果能够立即得出)。
model_fn
features
这是 input_fn 返回的第一项(input_fn 是 train, evaluate 和 predict 的参数)。类型应该是单一的 Tensor 或者 dict。
labels
这是 input_fn 返回的第二项。类型应该是单一的 Tensor 或者 dict。如果 mode 为 ModeKeys.PREDICT,则会默认为 labels=None。如果 model_fn 不接受 mode,model_fn 应该仍然可以处理 labels=None。
mode
可选。指定是训练、验证还是测试。参见 ModeKeys。
params
可选,超参数的 dict。 可以从超参数调整中配置Estimators。
config
可选,配置。如果没有传则为默认值。可以根据 num_ps_replicas 或 model_dir 等配置更新 model_fn。
返回
EstimatorSpec