Tensorflow gather

1 minute read

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

# Gather slices from `params` along axis `axis` according to `indices`.

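For example (note: the snippet below actually demonstrates `tf.squeeze`, not `tf.gather`; see the `tf.gather` calls in the "tensor example" section further down):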
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t))  # [2, 3]

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]

th.gather

Add Jupyter notebook path

# import sys,os
# sys.path.insert(0,"/home/home9/Documents/prj/ipython_work/deep learning/stock/scripts")
# print(sys.path)

import os
import sys
# Build the path of the bundled scripts directory, which lives under the
# parent of the current working directory, and make it importable.
nb_dir = os.path.dirname(os.getcwd())
nb_dir += '/stock-rnn-master/scripts/'
print(nb_dir)
# Register the working directory first, then the scripts directory.
# NOTE: there is no membership check, so re-running this cell appends
# duplicate entries to sys.path.
sys.path.extend([os.getcwd(), nb_dir])
print(sys.path)

tensor example

# https://www.tensorflow.org/api_guides/python/array_ops
import tensorflow as tf
import numpy as np
import pprint
tf.set_random_seed(777)  # for reproducibility

# Helpers shared by the cells below: a pretty-printer and a session.
pp = pprint.PrettyPrinter(indent=4)
# The .eval() calls below are made without passing a session explicitly;
# they rely on this InteractiveSession being available as the default one.
sess = tf.InteractiveSession()

# A flat list gives a rank-1 tensor: shape [4], values [1 2 3 4].
t = tf.constant([1,2,3,4])
print(tf.shape(t).eval())
print(t.eval())

>> [4]
>> [1 2 3 4]

# Wrapping the same values in an outer list adds a leading dimension:
# the shape becomes [1 4] and the values print as [[1 2 3 4]].
t = tf.constant([[1,2,3,4]])
print(tf.shape(t).eval())
print(t.eval())
>> [1 4]
>> [[1 2 3 4]]

# Three length-2 lists that will become the rows of the stacked result.
x = [1, 4]
y = [2, 5]
z = [3, 6]

# Pack along first dim: x, y and z become rows 0, 1 and 2 of a 3x2 result.
# .eval() runs the op immediately, so tarr holds the evaluated array
# (int32, per the output shown below) rather than a symbolic tensor.
tarr = tf.stack([x, y, z],0).eval()

tarr
>> array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

# Pick rows 2 and 0 of tarr (axis defaults to 0), in that order.
# tarr is rebound here from the evaluated array to a tensor, which is why
# .eval() is called again to see the values.
tarr = tf.gather(tarr,[2,0])
# tf.shape(tarr).eval()
tarr.eval()

>> array([[3, 6],
       [1, 4]], dtype=int32)

# Plain NumPy array of shape (2, 2, 3), used as the input of the
# transpose example that follows.
t = np.array([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]])
pp.pprint(t.shape)
pp.pprint(t)

>> (2, 2, 3)
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

# Reverse the axis order with perm [2, 1, 0]: shape (2, 2, 3) -> (3, 2, 2),
# so that t1[i][j][k] == t[k][j][i].
t1 = tf.transpose(t, [2, 1,0])
# t1 is a graph tensor, so it is evaluated through the session to print it.
pp.pprint(sess.run(t1).shape)
pp.pprint(sess.run(t1))

>> (3, 2, 2)
array([[[ 0,  6],
        [ 3,  9]],

       [[ 1,  7],
        [ 4, 10]],

       [[ 2,  8],
        [ 5, 11]]])

t1.get_shape()[0]
>> Dimension(3)

# Gather the last slice along axis 0: the static size of dimension 0 (3 here)
# is converted to int, and index size-1 selects t1[2]. Because the index is a
# scalar, that dimension is dropped and the result has shape (2, 2).
# The op is named "test_output".
last = tf.gather(t1,int(t1.get_shape()[0])-1,name="test_output")

last
>> <tf.Tensor 'test_output:0' shape=(2, 2) dtype=int64>

# Evaluate the gathered slice through the session and print its shape and
# values: (2, 2) and [[2, 8], [5, 11]].
pp.pprint(sess.run(last).shape)
pp.pprint(sess.run(last))

>> (2, 2)
array([[ 2,  8],
       [ 5, 11]])