您的位置:

TensorFlow中的切片操作——tf.slice

一、概述

切片操作是在TensorFlow中非常常见的一种操作,tf.slice函数就是专门用来进行切片操作的函数。

tf.slice函数的作用是从一个Tensor中提取出一部分数据,作为一个新的Tensor返回。

二、函数参数

tf.slice函数的函数参数非常简洁明了,在这里我们将分别对其中的三个参数进行介绍。

1. input_tensor

该参数表示输入的Tensor,可以是一个常量,也可以是一个变量。

2. begin

该参数表示开始切片的位置,在这里我们可以将它理解为一个坐标。

begin的数据类型必须是一个长度与input_tensor一样的一维数组,数组中的每个元素代表了一个维度上的起始位置。

3. size

该参数表示切片的大小,也可以将其理解为一个区域。

size的数据类型必须是一个长度与input_tensor一样的一维数组,数组中的每个元素代表了一个维度上的切片大小。

三、代码示例

下面我们将通过一些具体的例子来展示tf.slice函数的使用。所有的代码示例都可以在TensorFlow1.15版本下运行。

1. 示例1

首先我们来看一个简单的例子,假设有一个形状为[2,2,2]的Tensor a,我们要从其中取出第一个维度为0,第二个维度为1,第三个维度在前两个维度的基础上取0和1两个值的数据,代码如下:

import tensorflow as tf

a = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = tf.slice(a, [0, 1, 0], [1, 1, 2])

sess = tf.Session()
print(sess.run(b))

# 输出结果为:[[[3 4]]]

其中的[0, 1, 0]表示从第一个维度开始取第0个元素,第二个维度开始取第1个元素,第三个维度开始取第0个元素;[1, 1, 2]表示第一个维度上取1个元素,第二个维度上取1个元素,第三个维度上取2个元素。

2. 示例2

接着我们来看一个稍微复杂一些的例子,假设有一个形状为[2,2,2]的Tensor b,我们要从其中取出第1和第2个维度全部取出来,第0个维度在前两个维度的基础上分别取0和1两个值,代码如下:

import tensorflow as tf

b = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
c = tf.slice(b, [0, 0, 0], [2, 2, 2])

sess = tf.Session()
print(sess.run(c))

# 输出结果为:[[[1 2] [3 4]] [[5 6] [7 8]]]

其中的[0, 0, 0]表示从第一个维度开始取第0个元素,第二个维度开始取第0个元素,第三个维度开始取第0个元素;[2, 2, 2]表示第一个维度上取2个元素,第二个维度上取2个元素,第三个维度上取2个元素。

3. 示例3

最后我们来看一个比较灵活的例子,假设有一个形状为[2,3,4]的Tensor d,我们只需要取出第2个维度,而且第0个维度上的取值为0,第1个维度上的取值为1,代码如下:

import tensorflow as tf

d = tf.constant([[[1,  2,  3,  4], [5,  6,  7,  8], [9, 10, 11, 12]],
                 [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])
e = tf.slice(d, [0, 1, 0], [2, 1, 4])

sess = tf.Session()
print(sess.run(e))

# 输出结果为:[[[ 5  6  7  8]] [[17 18 19 20]]]

其中的[0, 1, 0]表示从第一个维度开始取第0个元素,第二个维度开始取第1个元素,第三个维度开始取第0个元素;[2, 1, 4]表示第一个维度上取2个元素,第二个维度上取1个元素,第三个维度上取4个元素。