shape和Dimension的关系

shape=(2,) TensorShape([Dimension(2)])

# b is a tensor which shape is [2, 3], return [2, 3]
b.get_shape().as_list()

tf.shape 是 TensorFlow 中的一个函数,用于获取张量(Tensor)的形状(即维度)。这个函数返回一个 1-D 张量,该张量包含输入张量每个维度的大小。tf.shape 的输出是一个整数类型的张量,其长度等于输入张量的秩(rank,即维度的数量)。

基本用法

import tensorflow as tf

# 创建一个张量
tensor = tf.constant([[1, 2, 3], [4, 5, 6]])

# 获取张量的形状
shape = tf.shape(tensor)

# 打印形状
print(shape)  # 输出: tf.Tensor([2 3], shape=(2,), dtype=int32)

详细说明

  • 输入参数:
  • input: 要查询形状的张量。

  • 返回值:

  • 一个 1-D 张量,表示输入张量的形状。返回的张量类型通常是 int32int64,具体取决于 TensorFlow 的配置。

示例

示例 1: 一维张量

import tensorflow as tf

# 创建一个一维张量
tensor = tf.constant([1, 2, 3, 4])

# 获取张量的形状
shape = tf.shape(tensor)

# 打印形状
print(shape)  # 输出: tf.Tensor([4], shape=(1,), dtype=int32)

示例 2: 二维张量

import tensorflow as tf

# 创建一个二维张量
tensor = tf.constant([[1, 2, 3], [4, 5, 6]])

# 获取张量的形状
shape = tf.shape(tensor)

# 打印形状
print(shape)  # 输出: tf.Tensor([2 3], shape=(2,), dtype=int32)

示例 3: 动态形状

在 TensorFlow 的动态图(Eager Execution)中,tf.shape 也可以用于动态形状的张量。例如,在模型训练过程中,输入张量的形状可能会发生变化。

import tensorflow as tf

# 创建一个占位符
input_tensor = tf.keras.Input(shape=(None, 32))

# 获取张量的形状
shape = tf.shape(input_tensor)

# 打印形状
print(shape)  # 输出: tf.Tensor([None 32], shape=(2,), dtype=int32)

注意事项

  • tf.shape 返回的是一个张量,而不是 Python 列表或元组。如果需要将其转换为 Python 列表,可以使用 shape.numpy().tolist()
  • 对于静态形状,可以使用 tensor.shape 属性,它返回一个 TensorShape 对象,可以直接访问各个维度的大小。
import tensorflow as tf

# 创建一个张量
tensor = tf.constant([[1, 2, 3], [4, 5, 6]])

# 获取张量的静态形状
static_shape = tensor.shape

# 打印静态形状
print(static_shape)  # 输出: (2, 3)

希望这些信息对您有所帮助!如果有更多问题,请随时提问。