tf.split (API r1.3)
1. tf.split
split(value, num_or_size_splits, axis=0, num=None, name="split")
Defined in tensorflow/python/ops/array_ops.py.
See the guide: Tensor Transformations > Slicing and Joining
Splits a tensor into sub tensors.
If num_or_size_splits is an integer, num_split, then value is split along dimension axis into num_split smaller tensors. This requires that num_split evenly divides value.shape[axis].
If num_or_size_splits is not an integer, it is presumed to be a Tensor, size_splits, and value is split into len(size_splits) pieces. The i-th piece has the same shape as value except along dimension axis, where its size is size_splits[i].
For example:
# "value" is a tensor with shape [5, 30]
# Split "value" into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]
# Split "value" into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]
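The shape rules above can be checked without running TensorFlow. What follows is a minimal plain-Python sketch, not TensorFlow's implementation: `split_shapes` is a hypothetical helper introduced only for illustration, and it assumes the shape of value is fully known and that axis is non-negative.

```python
def split_shapes(shape, num_or_size_splits, axis=0):
    """Return the piece shapes described in the tf.split documentation.

    Hypothetical helper for illustration only; it works on plain lists,
    not on tensors.
    """
    shape = list(shape)
    dim = shape[axis]
    if isinstance(num_or_size_splits, int):
        # Integer case: the split dimension must divide evenly.
        if dim % num_or_size_splits != 0:
            raise ValueError(
                "%d does not evenly divide %d" % (num_or_size_splits, dim))
        sizes = [dim // num_or_size_splits] * num_or_size_splits
    else:
        # Size-list case: the sizes must add up to the split dimension.
        sizes = list(num_or_size_splits)
        if sum(sizes) != dim:
            raise ValueError(
                "sizes sum to %d, expected %d" % (sum(sizes), dim))
    pieces = []
    for size in sizes:
        piece = list(shape)
        piece[axis] = size  # only the split dimension changes
        pieces.append(piece)
    return pieces


print(split_shapes([5, 30], [4, 15, 11], axis=1))  # [[5, 4], [5, 15], [5, 11]]
print(split_shapes([5, 30], 3, axis=1))            # [[5, 10], [5, 10], [5, 10]]
```

Both calls reproduce the shapes listed in the example: only the entry at position axis changes, and every other dimension is carried over from value.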
Args:
value: The Tensor to split.
num_or_size_splits: Either a 0-D integer Tensor indicating the number of splits along axis, or a 1-D integer Tensor containing the sizes of each output tensor along axis. If a scalar, it must evenly divide value.shape[axis]; otherwise the sizes must sum to value.shape[axis].
axis: A 0-D int32 Tensor. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0.
num: Optional, used to specify the number of outputs when it cannot be inferred from the shape of size_splits.
name: A name for the operation (optional).
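The range [-rank(value), rank(value)) for axis means that negative values are accepted. As a small plain-Python sketch, assuming the usual convention that a negative axis counts back from the last dimension, the mapping looks like this; `normalize_axis` is a hypothetical helper, not a TensorFlow function.

```python
def normalize_axis(axis, rank):
    """Map an axis in [-rank, rank) to its non-negative equivalent.

    Hypothetical helper for illustration; assumes negative axes count
    back from the last dimension.
    """
    if not -rank <= axis < rank:
        raise ValueError(
            "axis %d is outside the range [%d, %d)" % (axis, -rank, rank))
    return axis + rank if axis < 0 else axis


rank = 2  # e.g. a tensor with shape [5, 30]
print(normalize_axis(-1, rank))  # 1
print(normalize_axis(0, rank))   # 0
```

For a tensor with shape [5, 30], axis=-1 and axis=1 therefore name the same dimension.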
Returns:
If num_or_size_splits is a scalar, returns num_or_size_splits Tensor objects; if num_or_size_splits is a 1-D Tensor, returns num_or_size_splits.get_shape()[0] Tensor objects resulting from splitting value.
Raises:
ValueError: If num is unspecified and cannot be inferred.
2. example 1
import tensorflow as tf
import numpy as np

batch_size = 1
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, n_steps, n_input)
x_anchor = tf.constant([[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]], dtype=np.float32)

# permute num_steps and batch_size
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# (num_steps*batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])

# Split data because rnn cell needs a list of inputs for the RNN inner loop
# n_steps * (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)

with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print(" ")

    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print(" ")

    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print(" ")

    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print(" ")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:" % (step))
        print(output_split[step])
        print("type(output_split[%d]):%s" % (step, type(output_split[step])))
    print(" ")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" % (step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s" % (step, type([output_split[step]])))
output:
type(input_anchor):
<type "numpy.ndarray">
input_anchor.shape:
(1, 6, 2)

type(output_anchor):
<type "numpy.ndarray">
output_anchor.shape:
(6, 1, 2)
output_anchor:
[[[ 0. 1.]]
 [[ 2. 3.]]
 [[ 4. 5.]]
 [[ 6. 7.]]
 [[ 8. 9.]]
 [[ 10. 11.]]]

type(output_reshape):
<type "numpy.ndarray">
output_reshape.shape:
(6, 2)
output_reshape:
[[ 0. 1.]
 [ 2. 3.]
 [ 4. 5.]
 [ 6. 7.]
 [ 8. 9.]
 [ 10. 11.]]

type(output_split):
<type "list">
output_split:
[array([[ 0., 1.]], dtype=float32), array([[ 2., 3.]], dtype=float32), array([[ 4., 5.]], dtype=float32), array([[ 6., 7.]], dtype=float32), array([[ 8., 9.]], dtype=float32), array([[ 10., 11.]], dtype=float32)]

output_split[0-5]:
output_split[0]:
[[ 0. 1.]]
type(output_split[0]):<type "numpy.ndarray">
output_split[1]:
[[ 2. 3.]]
type(output_split[1]):<type "numpy.ndarray">
output_split[2]:
[[ 4. 5.]]
type(output_split[2]):<type "numpy.ndarray">
output_split[3]:
[[ 6. 7.]]
type(output_split[3]):<type "numpy.ndarray">
output_split[4]:
[[ 8. 9.]]
type(output_split[4]):<type "numpy.ndarray">
output_split[5]:
[[ 10. 11.]]
type(output_split[5]):<type "numpy.ndarray">

output_split[0-5]:
[output_split[0]]:
[array([[ 0., 1.]], dtype=float32)]
type([output_split[0]]):<type "list">
[output_split[1]]:
[array([[ 2., 3.]], dtype=float32)]
type([output_split[1]]):<type "list">
[output_split[2]]:
[array([[ 4., 5.]], dtype=float32)]
type([output_split[2]]):<type "list">
[output_split[3]]:
[array([[ 6., 7.]], dtype=float32)]
type([output_split[3]]):<type "list">
[output_split[4]]:
[array([[ 8., 9.]], dtype=float32)]
type([output_split[4]]):<type "list">
[output_split[5]]:
[array([[ 10., 11.]], dtype=float32)]
type([output_split[5]]):<type "list">
3. example 2
import tensorflow as tf
import numpy as np

batch_size = 2
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, n_steps, n_input)
x_anchor = tf.constant([[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]],
                        [[12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23]]], dtype=np.float32)

# permute num_steps and batch_size
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# (num_steps*batch_size, num_input)
y_reshape = tf.reshape(y_anchor, [num_steps * batch_size, num_input])

# Split data because rnn cell needs a list of inputs for the RNN inner loop
# n_steps * (batch_size, num_input)
y_split = tf.split(y_reshape, num_steps, 0)

with tf.Session() as sess:
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print(" ")

    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print(" ")

    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print(" ")

    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print(" ")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("output_split[%d]:" % (step))
        print(output_split[step])
        print("type(output_split[%d]):%s" % (step, type(output_split[step])))
    print(" ")

    print("output_split[0-5]:")
    for step in range(num_steps):
        print("[output_split[%d]]:" % (step))
        print([output_split[step]])
        print("type([output_split[%d]]):%s" % (step, type([output_split[step]])))
output:
type(input_anchor):
<type "numpy.ndarray">
input_anchor.shape:
(2, 6, 2)

type(output_anchor):
<type "numpy.ndarray">
output_anchor.shape:
(6, 2, 2)
output_anchor:
[[[ 0. 1.]
  [ 12. 13.]]
 [[ 2. 3.]
  [ 14. 15.]]
 [[ 4. 5.]
  [ 16. 17.]]
 [[ 6. 7.]
  [ 18. 19.]]
 [[ 8. 9.]
  [ 20. 21.]]
 [[ 10. 11.]
  [ 22. 23.]]]

type(output_reshape):
<type "numpy.ndarray">
output_reshape.shape:
(12, 2)
output_reshape:
[[ 0. 1.]
 [ 12. 13.]
 [ 2. 3.]
 [ 14. 15.]
 [ 4. 5.]
 [ 16. 17.]
 [ 6. 7.]
 [ 18. 19.]
 [ 8. 9.]
 [ 20. 21.]
 [ 10. 11.]
 [ 22. 23.]]

type(output_split):
<type "list">
output_split:
[array([[ 0., 1.], [ 12., 13.]], dtype=float32),
 array([[ 2., 3.], [ 14., 15.]], dtype=float32),
 array([[ 4., 5.], [ 16., 17.]], dtype=float32),
 array([[ 6., 7.], [ 18., 19.]], dtype=float32),
 array([[ 8., 9.], [ 20., 21.]], dtype=float32),
 array([[ 10., 11.], [ 22., 23.]], dtype=float32)]

output_split[0-5]:
output_split[0]:
[[ 0. 1.]
 [ 12. 13.]]
type(output_split[0]):<type "numpy.ndarray">
output_split[1]:
[[ 2. 3.]
 [ 14. 15.]]
type(output_split[1]):<type "numpy.ndarray">
output_split[2]:
[[ 4. 5.]
 [ 16. 17.]]
type(output_split[2]):<type "numpy.ndarray">
output_split[3]:
[[ 6. 7.]
 [ 18. 19.]]
type(output_split[3]):<type "numpy.ndarray">
output_split[4]:
[[ 8. 9.]
 [ 20. 21.]]
type(output_split[4]):<type "numpy.ndarray">
output_split[5]:
[[ 10. 11.]
 [ 22. 23.]]
type(output_split[5]):<type "numpy.ndarray">

output_split[0-5]:
[output_split[0]]:
[array([[ 0., 1.], [ 12., 13.]], dtype=float32)]
type([output_split[0]]):<type "list">
[output_split[1]]:
[array([[ 2., 3.], [ 14., 15.]], dtype=float32)]
type([output_split[1]]):<type "list">
[output_split[2]]:
[array([[ 4., 5.], [ 16., 17.]], dtype=float32)]
type([output_split[2]]):<type "list">
[output_split[3]]:
[array([[ 6., 7.], [ 18., 19.]], dtype=float32)]
type([output_split[3]]):<type "list">
[output_split[4]]:
[array([[ 8., 9.], [ 20., 21.]], dtype=float32)]
type([output_split[4]]):<type "list">
[output_split[5]]:
[array([[ 10., 11.], [ 22., 23.]], dtype=float32)]
type([output_split[5]]):<type "list">
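To see why each piece pairs the same time step from both batch items, the transpose, reshape and split steps can be traced with nested Python lists. This is only an illustrative sketch of the data movement, written without TensorFlow; the list comprehensions stand in for tf.transpose, tf.reshape and tf.split and are not how those ops are implemented.

```python
batch_size, num_steps, num_input = 2, 6, 2

# x[b][s] is the input vector of batch item b at time step s,
# mirroring x_anchor with shape (batch_size, num_steps, num_input).
x = [[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]],
     [[12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23]]]

# Stand-in for transpose with perm=[1, 0, 2]: y[s][b] == x[b][s].
y = [list(step) for step in zip(*x)]

# Stand-in for reshape to (num_steps * batch_size, num_input):
# flatten the two leading axes, so rows of one time step stay adjacent.
flat = [row for step in y for row in step]

# Stand-in for split into num_steps equal pieces along axis 0:
# each piece holds batch_size consecutive rows.
pieces = [flat[i * batch_size:(i + 1) * batch_size] for i in range(num_steps)]

print(len(pieces))  # 6
print(pieces[0])    # [[0, 1], [12, 13]]
print(pieces[5])    # [[10, 11], [22, 23]]
```

Because the transpose puts the time axis first, the flattened rows alternate between the two batch items, and each chunk of batch_size rows is one time step for the whole batch. Skipping the transpose would instead group consecutive time steps of a single batch item into each piece.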