tf.feature_column.weighted_categorical_column
是 TensorFlow 中的一个功能列(Feature Column),它用于将权重值应用于分类特征列(Categorical Column)。这个功能在处理具有不同权重的分类特征时非常有用,例如在文本分类中,不同的单词可能具有不同的权重。
函数定义
tf.feature_column.weighted_categorical_column
的定义如下:
tf.feature_column.weighted_categorical_column(
categorical_column,
weight_feature_key,
dtype=tf.dtypes.float32
)
-
categorical_column
:由categorical_column_with_*
函数创建的CategoricalColumn
。 -
weight_feature_key
:权重值的字符串键。 -
dtype
:权重类型,例如tf.float32
。仅支持浮点和整数权重。
返回值
该函数返回一个 CategoricalColumn
,由两个稀疏特征组成:一个代表 ID,另一个代表该示例中 ID 特征的权重(值)。
抛出异常
ValueError
:如果dtype
不能转换为浮点数。
使用示例
假设你有一个文本文档,你将其表示为词频的集合,你可以提供两个并行稀疏输入特征('terms' 和 'frequencies')。以下是如何使用 tf.feature_column.weighted_categorical_column
的一个示例:
categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
column_name='terms', hash_bucket_size=1000)
weighted_column = tf.feature_column.weighted_categorical_column(
categorical_column=categorical_column, weight_feature_key='frequencies')
columns = [weighted_column, ...]
features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
linear_prediction, _, _ = linear_model(features, columns)
在这个例子中,输入字典包含键 'terms' 的 SparseTensor
和键 'frequencies' 的 SparseTensor
。这两个张量必须具有相同的索引和密集形状。