tf.split (API r1.3)
1. tf.split
split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name="split"
)
Defined in tensorflow/python/ops/array_ops.py.
See the guide: Tensor Transformations > Slicing and Joining
Splits a tensor into sub tensors.
If num_or_size_splits is an integer type, num_split, then value is split along dimension axis into num_split smaller tensors. This requires that num_split evenly divides value.shape[axis].
If num_or_size_splits is not an integer type, it is presumed to be a Tensor size_splits, and value is split into len(size_splits) pieces. The i-th piece has the same shape as value except along dimension axis, where its size is size_splits[i].
For example:
# "value" is a tensor with shape [5, 30]
# Split "value" into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]
# Split "value" into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]
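The two modes described above can be sketched in plain Python, without TensorFlow. The helper name split_sizes is hypothetical and is not part of the TensorFlow API. It only computes the size of each piece along the split dimension, and raising ValueError is a choice made for this sketch, not a statement about which error TensorFlow raises.

```python
def split_sizes(dim_size, num_or_size_splits):
    """Return the size of each piece along the split axis.

    Hypothetical helper for illustration; it mimics the documented rules only.
    """
    if isinstance(num_or_size_splits, int):
        # Integer mode: num_split equal pieces, so it must divide dim_size.
        num_split = num_or_size_splits
        if num_split <= 0 or dim_size % num_split != 0:
            raise ValueError(
                "%d does not evenly divide %d" % (num_split, dim_size))
        return [dim_size // num_split] * num_split

    # List mode: the explicit sizes must add up to the dimension size.
    size_splits = list(num_or_size_splits)
    if sum(size_splits) != dim_size:
        raise ValueError(
            "sizes %r do not sum to %d" % (size_splits, dim_size))
    return size_splits


print(split_sizes(30, 3))            # [10, 10, 10]
print(split_sizes(30, [4, 15, 11]))  # [4, 15, 11]
```

With the [5, 30] tensor from the example, these are the sizes the pieces take along dimension 1.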
Args:
value: The Tensor to split.
num_or_size_splits: Either a 0-D integer Tensor indicating the number of splits along axis or a 1-D integer Tensor containing the sizes of each output tensor along axis. If a scalar then it must evenly divide value.shape[axis]; otherwise the sum of sizes along the split dimension must match that of the value.
axis: A 0-D int32 Tensor. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0.
num: Optional, used to specify the number of outputs when it cannot be inferred from the shape of size_splits.
name: A name for the operation (optional).
Returns:
If num_or_size_splits is a scalar, returns num_or_size_splits Tensor objects; if num_or_size_splits is a 1-D Tensor, returns num_or_size_splits.get_shape()[0] Tensor objects resulting from splitting value.
Raises:
ValueError: If num is unspecified and cannot be inferred.
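As a rough illustration of the axis range [-rank(value), rank(value)) and of the shapes of the returned pieces, here is a plain-Python sketch. The function piece_shapes is a hypothetical name introduced for this note, and it works on shape lists rather than tensors. It assumes that a negative axis counts from the last dimension, as is usual for this kind of range.

```python
def piece_shapes(value_shape, size_splits, axis=0):
    """Return the shape of every piece for explicit sizes along `axis`.

    Hypothetical helper: operates on Python lists, not on tensors.
    """
    rank = len(value_shape)
    if not -rank <= axis < rank:
        raise ValueError("axis %d is outside [%d, %d)" % (axis, -rank, rank))
    axis %= rank  # map a negative axis to its non-negative equivalent

    if sum(size_splits) != value_shape[axis]:
        raise ValueError("sizes must sum to value_shape[axis]")

    shapes = []
    for size in size_splits:
        shape = list(value_shape)  # same as value ...
        shape[axis] = size         # ... except along the split axis
        shapes.append(shape)
    return shapes


# axis=-1 refers to the last dimension, which is dimension 1 here.
print(piece_shapes([5, 30], [4, 15, 11], axis=-1))
# [[5, 4], [5, 15], [5, 11]]
```

The number of shapes returned equals len(size_splits), which matches the number of Tensor objects described under Returns.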
2. example 1
import tensorflow as tf
import numpy as np
batch_size = 1
num_steps = 6
num_input = 2
# x_anchor shape: (batch_size, num_steps, num_input)
x_anchor = tf.constant([[[0, 1],
                         [2, 3],
                         [4, 5],
                         [6, 7],
                         [8, 9],
                         [10, 11]]], dtype=np.float32)
# permute num_steps and batch_size
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])
# (num_steps*batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])
# Split the data because the RNN cell needs a list of inputs for the RNN inner loop:
# num_steps tensors, each of shape (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)
with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print("\n")
    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print("\n")
    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print("\n")
    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print("\n")
    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:" % (step))
        print(output_split[step])
        print("type(output_split[%d]):%s" % (step, type(output_split[step])))
    print("\n")
    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" % (step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s" % (step, type([output_split[step]])))
output:
type(input_anchor):
<type "numpy.ndarray">
input_anchor.shape:
(1, 6, 2)
type(output_anchor):
<type "numpy.ndarray">
output_anchor.shape:
(6, 1, 2)
output_anchor:
[[[ 0. 1.]]
[[ 2. 3.]]
[[ 4. 5.]]
[[ 6. 7.]]
[[ 8. 9.]]
[[ 10. 11.]]]
type(output_reshape):
<type "numpy.ndarray">
output_reshape.shape:
(6, 2)
output_reshape:
[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[ 10. 11.]]
type(output_split):
<type "list">
output_split:
[array([[ 0., 1.]], dtype=float32), array([[ 2., 3.]], dtype=float32), array([[ 4., 5.]], dtype=float32), array([[ 6., 7.]], dtype=float32), array([[ 8., 9.]], dtype=float32), array([[ 10., 11.]], dtype=float32)]
output_split[0-5]:
output_split[0]:
[[ 0. 1.]]
type(output_split[0]):<type "numpy.ndarray">
output_split[1]:
[[ 2. 3.]]
type(output_split[1]):<type "numpy.ndarray">
output_split[2]:
[[ 4. 5.]]
type(output_split[2]):<type "numpy.ndarray">
output_split[3]:
[[ 6. 7.]]
type(output_split[3]):<type "numpy.ndarray">
output_split[4]:
[[ 8. 9.]]
type(output_split[4]):<type "numpy.ndarray">
output_split[5]:
[[ 10. 11.]]
type(output_split[5]):<type "numpy.ndarray">
output_split[0-5]:
[output_split[0]]:
[array([[ 0., 1.]], dtype=float32)]
type([output_split[0]]):<type "list">
[output_split[1]]:
[array([[ 2., 3.]], dtype=float32)]
type([output_split[1]]):<type "list">
[output_split[2]]:
[array([[ 4., 5.]], dtype=float32)]
type([output_split[2]]):<type "list">
[output_split[3]]:
[array([[ 6., 7.]], dtype=float32)]
type([output_split[3]]):<type "list">
[output_split[4]]:
[array([[ 8., 9.]], dtype=float32)]
type([output_split[4]]):<type "list">
[output_split[5]]:
[array([[ 10., 11.]], dtype=float32)]
type([output_split[5]]):<type "list">
Process finished with exit code 0
3. example 2
import tensorflow as tf
import numpy as np
batch_size = 2
num_steps = 6
num_input = 2
# x_anchor shape: (batch_size, num_steps, num_input)
x_anchor = tf.constant([[[0, 1],
                         [2, 3],
                         [4, 5],
                         [6, 7],
                         [8, 9],
                         [10, 11]],
                        [[12, 13],
                         [14, 15],
                         [16, 17],
                         [18, 19],
                         [20, 21],
                         [22, 23]]], dtype=np.float32)
# permute num_steps and batch_size
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])
# (num_steps*batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])
# Split the data because the RNN cell needs a list of inputs for the RNN inner loop:
# num_steps tensors, each of shape (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)
with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print("\n")
    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print("\n")
    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print("\n")
    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print("\n")
    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:" % (step))
        print(output_split[step])
        print("type(output_split[%d]):%s" % (step, type(output_split[step])))
    print("\n")
    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" % (step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s" % (step, type([output_split[step]])))
output:
type(input_anchor):
<type "numpy.ndarray">
input_anchor.shape:
(2, 6, 2)
type(output_anchor):
<type "numpy.ndarray">
output_anchor.shape:
(6, 2, 2)
output_anchor:
[[[ 0. 1.]
[ 12. 13.]]
[[ 2. 3.]
[ 14. 15.]]
[[ 4. 5.]
[ 16. 17.]]
[[ 6. 7.]
[ 18. 19.]]
[[ 8. 9.]
[ 20. 21.]]
[[ 10. 11.]
[ 22. 23.]]]
type(output_reshape):
<type "numpy.ndarray">
output_reshape.shape:
(12, 2)
output_reshape:
[[ 0. 1.]
[ 12. 13.]
[ 2. 3.]
[ 14. 15.]
[ 4. 5.]
[ 16. 17.]
[ 6. 7.]
[ 18. 19.]
[ 8. 9.]
[ 20. 21.]
[ 10. 11.]
[ 22. 23.]]
type(output_split):
<type "list">
output_split:
[array([[ 0., 1.],
[ 12., 13.]], dtype=float32), array([[ 2., 3.],
[ 14., 15.]], dtype=float32), array([[ 4., 5.],
[ 16., 17.]], dtype=float32), array([[ 6., 7.],
[ 18., 19.]], dtype=float32), array([[ 8., 9.],
[ 20., 21.]], dtype=float32), array([[ 10., 11.],
[ 22., 23.]], dtype=float32)]
output_split[0-5]:
output_split[0]:
[[ 0. 1.]
[ 12. 13.]]
type(output_split[0]):<type "numpy.ndarray">
output_split[1]:
[[ 2. 3.]
[ 14. 15.]]
type(output_split[1]):<type "numpy.ndarray">
output_split[2]:
[[ 4. 5.]
[ 16. 17.]]
type(output_split[2]):<type "numpy.ndarray">
output_split[3]:
[[ 6. 7.]
[ 18. 19.]]
type(output_split[3]):<type "numpy.ndarray">
output_split[4]:
[[ 8. 9.]
[ 20. 21.]]
type(output_split[4]):<type "numpy.ndarray">
output_split[5]:
[[ 10. 11.]
[ 22. 23.]]
type(output_split[5]):<type "numpy.ndarray">
output_split[0-5]:
[output_split[0]]:
[array([[ 0., 1.],
[ 12., 13.]], dtype=float32)]
type([output_split[0]]):<type "list">
[output_split[1]]:
[array([[ 2., 3.],
[ 14., 15.]], dtype=float32)]
type([output_split[1]]):<type "list">
[output_split[2]]:
[array([[ 4., 5.],
[ 16., 17.]], dtype=float32)]
type([output_split[2]]):<type "list">
[output_split[3]]:
[array([[ 6., 7.],
[ 18., 19.]], dtype=float32)]
type([output_split[3]]):<type "list">
[output_split[4]]:
[array([[ 8., 9.],
[ 20., 21.]], dtype=float32)]
type([output_split[4]]):<type "list">
[output_split[5]]:
[array([[ 10., 11.],
[ 22., 23.]], dtype=float32)]
type([output_split[5]]):<type "list">
Process finished with exit code 0
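To see why each piece in example 2 holds one row per batch element, the transpose, reshape and split steps can be replayed in plain Python. This is only a sketch: nested lists stand in for tensors, integers stand in for the float32 values, and the names x, transposed, rows and pieces are introduced here for illustration.

```python
# Plain-Python replay of transpose -> reshape -> split from example 2.
batch_size, num_steps, num_input = 2, 6, 2

# x[b][t] is the input vector of batch element b at time step t;
# the values match x_anchor in example 2 (0..11 and 12..23).
x = [[[2 * (b * num_steps + t), 2 * (b * num_steps + t) + 1]
      for t in range(num_steps)]
     for b in range(batch_size)]

# Like tf.transpose(x, perm=[1, 0, 2]): swap the batch and time axes.
transposed = [[x[b][t] for b in range(batch_size)]
              for t in range(num_steps)]

# Like tf.reshape(..., [num_steps * batch_size, num_input]):
# flatten the first two axes into a single list of rows.
rows = [row for step in transposed for row in step]

# Like tf.split(rows, num_steps, 0): cut the rows into num_steps equal pieces.
pieces = [rows[t * batch_size:(t + 1) * batch_size]
          for t in range(num_steps)]

print(pieces[0])  # [[0, 1], [12, 13]]
print(pieces[5])  # [[10, 11], [22, 23]]
```

Because the transpose puts the time axis first, consecutive rows after the reshape belong to the same time step, so each of the num_steps pieces has shape (batch_size, num_input).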
