tf.feature_column.weighted_categorical_column 是 TensorFlow 中的一个功能列(Feature Column),它用于将权重值应用于分类特征列(Categorical Column)。这个功能在处理具有不同权重的分类特征时非常有用,例如在文本分类中,不同的单词可能具有不同的权重。

函数定义

tf.feature_column.weighted_categorical_column 的定义如下:

tf.feature_column.weighted_categorical_column(
    categorical_column,
    weight_feature_key,
    dtype=tf.dtypes.float32
)
  • categorical_column:由 categorical_column_with_* 函数创建的 CategoricalColumn

  • weight_feature_key:权重值的字符串键。

  • dtype:权重类型,例如 tf.float32。仅支持浮点和整数权重。

返回值

该函数返回一个 CategoricalColumn,由两个稀疏特征组成:一个代表 ID,另一个代表该示例中 ID 特征的权重(值)。

抛出异常

  • ValueError:如果 dtype 不能转换为浮点数。

使用示例

假设你有一个文本文档,你将其表示为词频的集合,你可以提供两个并行稀疏输入特征('terms' 和 'frequencies')。以下是如何使用 tf.feature_column.weighted_categorical_column 的一个示例:

categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
    column_name='terms', hash_bucket_size=1000)
weighted_column = tf.feature_column.weighted_categorical_column(
    categorical_column=categorical_column, weight_feature_key='frequencies')
columns = [weighted_column, ...]
features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
linear_prediction, _, _ = linear_model(features, columns)
在这个例子中,输入字典包含键 'terms' 的 SparseTensor 和键 'frequencies' 的 SparseTensor。这两个张量必须具有相同的索引和密集形状。