牛骨文 Education Service Platform (making learning simple)
Blog Notes

tf.split (API r1.3)

Created: 2017-11-10

1. tf.split

split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name="split"
)


Defined in tensorflow/python/ops/array_ops.py.

See the guide: Tensor Transformations > Slicing and Joining

Splits a tensor into subtensors.

If num_or_size_splits is an integer, num_split, then value is split along dimension axis into num_split smaller tensors. This requires that num_split evenly divide value.shape[axis].
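A minimal pure-Python sketch of the integer rule, assuming the size along axis is statically known. `even_split_sizes` is a hypothetical helper for illustration, not part of TensorFlow.

```python
def even_split_sizes(dim_size, num_split):
    """Per-piece sizes along `axis` for an integer num_or_size_splits.

    Hypothetical helper: num_split must evenly divide value.shape[axis],
    and every piece then gets the same size.
    """
    if num_split <= 0 or dim_size % num_split != 0:
        raise ValueError(
            "num_split=%d does not evenly divide dimension size %d"
            % (num_split, dim_size))
    return [dim_size // num_split] * num_split


print(even_split_sizes(30, 3))  # [10, 10, 10]
```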


If num_or_size_splits is not an integer, it is presumed to be a Tensor size_splits, and value is split into len(size_splits) pieces. The i-th piece has the same shape as value, except along dimension axis, where its size is size_splits[i].
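The size-list rule can be sketched the same way as a shape calculation. `split_output_shapes` is again a hypothetical helper (not a TensorFlow function); it assumes a fully known static shape and covers only the rules stated in this section.

```python
def split_output_shapes(shape, size_splits, axis=0):
    """Output shapes for a list of sizes (hypothetical helper, not a TF API).

    Each piece keeps `shape` except along `axis`, where piece i has
    size size_splits[i]; the sizes must add up to shape[axis].
    """
    rank = len(shape)
    if not -rank <= axis < rank:
        raise ValueError("axis %d out of range for rank %d" % (axis, rank))
    axis %= rank  # map a negative axis onto [0, rank)
    if sum(size_splits) != shape[axis]:
        raise ValueError("sizes %r do not sum to %d" % (size_splits, shape[axis]))
    shapes = []
    for size in size_splits:
        piece = list(shape)
        piece[axis] = size
        shapes.append(piece)
    return shapes


print(split_output_shapes([5, 30], [4, 15, 11], axis=1))
# [[5, 4], [5, 15], [5, 11]]
```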

For example:

import tensorflow as tf

# "value" is a tensor with shape [5, 30] (zeros used here as a stand-in)
value = tf.zeros([5, 30])
# Split "value" into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]
# Split "value" into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]
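If TensorFlow 1.x is not at hand, the shapes in this example can be reproduced with NumPy as an analogy. One difference matters: `np.split` takes cut indices rather than piece sizes, so the sizes [4, 15, 11] are first turned into the cut points [4, 19] with a cumulative sum. The zero-filled `value` is an assumed stand-in for any [5, 30] tensor.

```python
import numpy as np

# Stand-in for "value": any array of shape [5, 30] (zeros here).
value = np.zeros((5, 30), dtype=np.float32)

# np.split wants cut *indices*, so convert the sizes [4, 15, 11]
# into the cut points [4, 19] along axis 1.
size_splits = [4, 15, 11]
cut_points = np.cumsum(size_splits)[:-1]
split0, split1, split2 = np.split(value, cut_points, axis=1)
print([s.shape for s in (split0, split1, split2)])  # [(5, 4), (5, 15), (5, 11)]

# An integer means "this many equal pieces" in both libraries.
split0, split1, split2 = np.split(value, 3, axis=1)
print(split0.shape)  # (5, 10)
```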


Args:

value: The Tensor to split.
num_or_size_splits: Either a 0-D integer Tensor indicating the number of splits along axis, or a 1-D integer Tensor containing the size of each output tensor along axis. If it is a scalar, it must evenly divide value.shape[axis]; otherwise, the sizes must sum to value.shape[axis].
axis: A 0-D int32 Tensor. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0.
num: Optional, used to specify the number of outputs when it cannot be inferred from the shape of size_splits.
name: A name for the operation (optional).

Returns:
If num_or_size_splits is a scalar, returns a list of num_or_size_splits Tensor objects; if num_or_size_splits is a 1-D Tensor, returns a list of num_or_size_splits.get_shape()[0] Tensor objects resulting from splitting value.
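To tie the arguments and the return value together, here is a NumPy sketch that follows the rules listed above and, like tf.split, returns a Python list. `split_like_tf` is a hypothetical helper, not TensorFlow's implementation; it ignores `num` and `name`, and it checks the sum of the sizes explicitly because `np.split` with cut indices would silently hand any remainder to the last piece.

```python
import numbers

import numpy as np


def split_like_tf(value, num_or_size_splits, axis=0):
    """Hypothetical NumPy helper mimicking tf.split's documented rules.

    An integer means "this many equal pieces"; a sequence means "pieces
    of these sizes". Returns a Python list of arrays. `num` and `name`
    are not modeled.
    """
    value = np.asarray(value)
    rank = value.ndim
    if not -rank <= axis < rank:
        raise ValueError("axis must be in [-rank(value), rank(value))")
    dim = value.shape[axis]
    if isinstance(num_or_size_splits, numbers.Integral):
        if num_or_size_splits <= 0 or dim % num_or_size_splits != 0:
            raise ValueError("num_split must evenly divide value.shape[axis]")
        return np.split(value, num_or_size_splits, axis=axis)
    sizes = [int(s) for s in num_or_size_splits]
    if sum(sizes) != dim:
        # np.split would not complain; the last piece would absorb the difference.
        raise ValueError("size_splits must sum to value.shape[axis]")
    # Convert sizes to cut indices, e.g. [4, 15, 11] -> [4, 19]
    return np.split(value, np.cumsum(sizes)[:-1], axis=axis)


value = np.zeros((5, 30), dtype=np.float32)
pieces = split_like_tf(value, [4, 15, 11], axis=-1)
print([p.shape for p in pieces])  # [(5, 4), (5, 15), (5, 11)]
```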

Raises:
ValueError: If num is unspecified and cannot be inferred.


2. example 1

import tensorflow as tf
import numpy as np

batch_size = 1
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, num_steps, num_input)
x_anchor = tf.constant([[[0, 1],
                         [2, 3],
                         [4, 5],
                         [6, 7],
                         [8, 9],
                         [10, 11]]], dtype=np.float32)

# swap the batch_size and num_steps axes: (num_steps, batch_size, num_input)
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# flatten to (num_steps * batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])

# Split the data because the RNN cell needs a list of inputs for the RNN inner loop:
# num_steps tensors, each of shape (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)

with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print("\n")

    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print("\n")

    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print("\n")

    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print("\n")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:"%(step))
        print(output_split[step])
        print("type(output_split[%d]):%s"%(step, type(output_split[step])))
    print("\n")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" %(step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s"%(step, type([output_split[step]])))


output:

type(input_anchor):
<type 'numpy.ndarray'>
input_anchor.shape:
(1, 6, 2)


type(output_anchor):
<type 'numpy.ndarray'>
output_anchor.shape:
(6, 1, 2)
output_anchor:
[[[  0.   1.]]

 [[  2.   3.]]

 [[  4.   5.]]

 [[  6.   7.]]

 [[  8.   9.]]

 [[ 10.  11.]]]


type(output_reshape):
<type 'numpy.ndarray'>
output_reshape.shape:
(6, 2)
output_reshape:
[[  0.   1.]
 [  2.   3.]
 [  4.   5.]
 [  6.   7.]
 [  8.   9.]
 [ 10.  11.]]


type(output_split):
<type 'list'>
output_split:
[array([[ 0.,  1.]], dtype=float32), array([[ 2.,  3.]], dtype=float32), array([[ 4.,  5.]], dtype=float32), array([[ 6.,  7.]], dtype=float32), array([[ 8.,  9.]], dtype=float32), array([[ 10.,  11.]], dtype=float32)]


output_split[0-5]:
output_split[0]:
[[ 0.  1.]]
type(output_split[0]):<type 'numpy.ndarray'>
output_split[1]:
[[ 2.  3.]]
type(output_split[1]):<type 'numpy.ndarray'>
output_split[2]:
[[ 4.  5.]]
type(output_split[2]):<type 'numpy.ndarray'>
output_split[3]:
[[ 6.  7.]]
type(output_split[3]):<type 'numpy.ndarray'>
output_split[4]:
[[ 8.  9.]]
type(output_split[4]):<type 'numpy.ndarray'>
output_split[5]:
[[ 10.  11.]]
type(output_split[5]):<type 'numpy.ndarray'>


output_split[0-5]:
[output_split[0]]:
[array([[ 0.,  1.]], dtype=float32)]
type([output_split[0]]):<type 'list'>
[output_split[1]]:
[array([[ 2.,  3.]], dtype=float32)]
type([output_split[1]]):<type 'list'>
[output_split[2]]:
[array([[ 4.,  5.]], dtype=float32)]
type([output_split[2]]):<type 'list'>
[output_split[3]]:
[array([[ 6.,  7.]], dtype=float32)]
type([output_split[3]]):<type 'list'>
[output_split[4]]:
[array([[ 8.,  9.]], dtype=float32)]
type([output_split[4]]):<type 'list'>
[output_split[5]]:
[array([[ 10.,  11.]], dtype=float32)]
type([output_split[5]]):<type 'list'>

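The transpose, reshape, and split steps can be checked outside TensorFlow with a NumPy sketch, assuming `np.transpose`, `reshape`, and `np.split` stand in for their TensorFlow counterparts on these shapes. It shows that piece t is simply time step t across the batch, i.e. `x[:, t, :]`. In TensorFlow 1.x, `tf.unstack(x_anchor, axis=1)` should give the same list of (batch_size, num_input) tensors in one call. `split_by_step` is a hypothetical helper name.

```python
import numpy as np


def split_by_step(x):
    """transpose -> reshape -> split, mirroring the example with NumPy."""
    batch_size, num_steps, num_input = x.shape
    # (num_steps, batch_size, num_input), then (num_steps * batch_size, num_input)
    y = np.transpose(x, (1, 0, 2)).reshape(num_steps * batch_size, num_input)
    # num_steps equal pieces along axis 0, each (batch_size, num_input)
    return np.split(y, num_steps, axis=0)


# Same values as example 1's x_anchor: shape (1, 6, 2)
x = np.arange(12, dtype=np.float32).reshape(1, 6, 2)
pieces = split_by_step(x)

print(len(pieces))           # 6
print(pieces[5].tolist())    # [[10.0, 11.0]]
# Piece t equals slicing time step t directly.
print(all(np.array_equal(p, x[:, t, :]) for t, p in enumerate(pieces)))  # True
```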

3. example 2

import tensorflow as tf
import numpy as np

batch_size = 2
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, num_steps, num_input)
x_anchor = tf.constant([[[0,   1],
                         [2,   3],
                         [4,   5],
                         [6,   7],
                         [8,   9],
                         [10, 11]],
                        [[12, 13],
                         [14, 15],
                         [16, 17],
                         [18, 19],
                         [20, 21],
                         [22, 23]]], dtype=np.float32)

# swap the batch_size and num_steps axes: (num_steps, batch_size, num_input)
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# flatten to (num_steps * batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])

# Split the data because the RNN cell needs a list of inputs for the RNN inner loop:
# num_steps tensors, each of shape (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)

with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print("\n")

    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print("\n")

    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print("\n")

    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print("\n")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:"%(step))
        print(output_split[step])
        print("type(output_split[%d]):%s"%(step, type(output_split[step])))
    print("\n")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" %(step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s"%(step, type([output_split[step]])))


output:

type(input_anchor):
<type 'numpy.ndarray'>
input_anchor.shape:
(2, 6, 2)


type(output_anchor):
<type 'numpy.ndarray'>
output_anchor.shape:
(6, 2, 2)
output_anchor:
[[[  0.   1.]
  [ 12.  13.]]

 [[  2.   3.]
  [ 14.  15.]]

 [[  4.   5.]
  [ 16.  17.]]

 [[  6.   7.]
  [ 18.  19.]]

 [[  8.   9.]
  [ 20.  21.]]

 [[ 10.  11.]
  [ 22.  23.]]]


type(output_reshape):
<type 'numpy.ndarray'>
output_reshape.shape:
(12, 2)
output_reshape:
[[  0.   1.]
 [ 12.  13.]
 [  2.   3.]
 [ 14.  15.]
 [  4.   5.]
 [ 16.  17.]
 [  6.   7.]
 [ 18.  19.]
 [  8.   9.]
 [ 20.  21.]
 [ 10.  11.]
 [ 22.  23.]]


type(output_split):
<type 'list'>
output_split:
[array([[  0.,   1.],
       [ 12.,  13.]], dtype=float32), array([[  2.,   3.],
       [ 14.,  15.]], dtype=float32), array([[  4.,   5.],
       [ 16.,  17.]], dtype=float32), array([[  6.,   7.],
       [ 18.,  19.]], dtype=float32), array([[  8.,   9.],
       [ 20.,  21.]], dtype=float32), array([[ 10.,  11.],
       [ 22.,  23.]], dtype=float32)]


output_split[0-5]:
output_split[0]:
[[  0.   1.]
 [ 12.  13.]]
type(output_split[0]):<type 'numpy.ndarray'>
output_split[1]:
[[  2.   3.]
 [ 14.  15.]]
type(output_split[1]):<type 'numpy.ndarray'>
output_split[2]:
[[  4.   5.]
 [ 16.  17.]]
type(output_split[2]):<type 'numpy.ndarray'>
output_split[3]:
[[  6.   7.]
 [ 18.  19.]]
type(output_split[3]):<type 'numpy.ndarray'>
output_split[4]:
[[  8.   9.]
 [ 20.  21.]]
type(output_split[4]):<type 'numpy.ndarray'>
output_split[5]:
[[ 10.  11.]
 [ 22.  23.]]
type(output_split[5]):<type 'numpy.ndarray'>


output_split[0-5]:
[output_split[0]]:
[array([[  0.,   1.],
       [ 12.,  13.]], dtype=float32)]
type([output_split[0]]):<type 'list'>
[output_split[1]]:
[array([[  2.,   3.],
       [ 14.,  15.]], dtype=float32)]
type([output_split[1]]):<type 'list'>
[output_split[2]]:
[array([[  4.,   5.],
       [ 16.,  17.]], dtype=float32)]
type([output_split[2]]):<type 'list'>
[output_split[3]]:
[array([[  6.,   7.],
       [ 18.,  19.]], dtype=float32)]
type([output_split[3]]):<type 'list'>
[output_split[4]]:
[array([[  8.,   9.],
       [ 20.,  21.]], dtype=float32)]
type([output_split[4]]):<type 'list'>
[output_split[5]]:
[array([[ 10.,  11.],
       [ 22.,  23.]], dtype=float32)]
type([output_split[5]]):<type 'list'>

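Example 2 shows why the transpose matters once batch_size > 1. In the NumPy sketch below (NumPy standing in for TensorFlow, with the same values as x_anchor), the rows of the transposed-and-reshaped array are time-major, so row t * batch_size + b holds step t of batch element b. Reshaping without the transpose keeps the rows batch-major, and each piece then holds two consecutive time steps of one sequence instead of one time step of every sequence.

```python
import numpy as np

batch_size, num_steps, num_input = 2, 6, 2
# Same values as example 2's x_anchor: shape (batch_size, num_steps, num_input)
x = np.arange(batch_size * num_steps * num_input, dtype=np.float32).reshape(
    batch_size, num_steps, num_input)

# With the transpose: rows are time-major.
y = np.transpose(x, (1, 0, 2)).reshape(num_steps * batch_size, num_input)
for t in range(num_steps):
    for b in range(batch_size):
        assert np.array_equal(y[t * batch_size + b], x[b, t])
right = np.split(y, num_steps, axis=0)

# Without the transpose: rows stay batch-major, so each piece holds two
# consecutive time steps of a single batch element.
wrong = np.split(x.reshape(num_steps * batch_size, num_input), num_steps, axis=0)

print(right[0].tolist())  # [[0.0, 1.0], [12.0, 13.0]]
print(wrong[0].tolist())  # [[0.0, 1.0], [2.0, 3.0]]
```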
