TensorFlow: Connecting Two Graphs Together using "import_graph_def"

Connecting Two Graphs Together using import_graph_def

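Caution: in `return_elements` you need to specify the names of tensors, not operations, if you want `Tensor` objects back (a bare operation name returns an `Operation` object instead). For example, `tf.add(a, b, name='add')` creates an operation named "add"; its output tensors are named `add:<index>`. A safe choice is `add:0`, the operation's first output, which exists for any operation that produces at least one output.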


In [1]:
import tensorflow as tf

# Shared element type for every placeholder in this notebook (32-bit integer).
INT = tf.int32


In [2]:
def graph_add(dtype=INT):
    """Build a standalone graph computing ``c = a + b`` and return its GraphDef.

    The graph contains two scalar placeholders named ``'a'`` and ``'b'`` and
    an add op named ``'c'``. Callers wire real inputs in with
    ``input_map={'a': ..., 'b': ...}`` and fetch the sum as ``'c:0'``.

    Args:
        dtype: element type of both placeholders; defaults to ``INT``.

    Returns:
        The ``GraphDef`` protocol buffer describing the graph.
    """
    # computes c = a + b
    g = tf.Graph()
    # Build inside a private graph so nothing is added to the default graph.
    with g.as_default():
        a = tf.placeholder(dtype, [], name='a')
        b = tf.placeholder(dtype, [], name='b')
        # Only the op's graph name ('c') is used by importers; the Python
        # handle is not needed.
        tf.add(a, b, name='c')
    return g.as_graph_def()

def graph_pow(exponent=2, dtype=INT):
    """Build a standalone graph computing ``e = d ** exponent``; return its GraphDef.

    The graph contains one scalar placeholder named ``'d'`` and a pow op
    named ``'e'``. Callers wire a real input in with ``input_map={'d': ...}``
    and fetch the result as ``'e:0'``.

    Args:
        exponent: power to raise ``d`` to; defaults to 2 (squaring).
        dtype: element type of the placeholder; defaults to ``INT``.

    Returns:
        The ``GraphDef`` protocol buffer describing the graph.
    """
    # computes e = d ^ exponent
    g = tf.Graph()
    # Build inside a private graph so nothing is added to the default graph.
    with g.as_default():
        d = tf.placeholder(dtype, [], name='d')
        # Only the op's graph name ('e') is used by importers; the Python
        # handle is not needed.
        tf.pow(d, exponent, name='e')
    return g.as_graph_def()


In [3]:
# Start from an empty default graph so re-running this cell does not keep
# adding duplicate nodes.
tf.reset_default_graph()

# Inputs of the main graph: two scalar placeholders fed at run time.
a = tf.placeholder(INT, [], name='a')
b = tf.placeholder(INT, [], name='b')

# Import graph_add into the default graph. input_map replaces the imported
# placeholders 'a' and 'b' with our own tensors a and b. return_elements asks
# for the tensor 'c:0', so exactly one element comes back, unpacked as g1_c.
[g1_c] = tf.import_graph_def(
    graph_add(), input_map={'a': a,
                            'b': b}, return_elements=['c:0'])

# Import graph_pow, feeding the output of the add graph (g1_c) into its
# placeholder 'd'. Its result tensor 'e:0' is returned as g2_e.
[g2_e] = tf.import_graph_def(
    graph_pow(), input_map={'d': g1_c}, return_elements=['e:0'])

# Evaluate both the intermediate sum (g1_c) and the final result (g2_e) in a
# single run. Only the main graph's placeholders a and b need to be fed.
with tf.Session() as sess:
    c, e = sess.run([g1_c, g2_e], feed_dict={a: 10, b: 20})
    print('a + b =', c)
    print('(a + b)^2 =', e)


a + b = 30
(a + b)^2 = 900


Konpat Preechakul

Read more posts by this author.