TensorFlow: Connecting Two Graphs Together Using `import_graph_def`
Connecting Two Graphs Together Using `import_graph_def`
Caution: in `return_elements` you need to specify the names of tensors, not operations, if you want tensors back.
For example, `tf.add(a, b, name='add')` creates an operation named `add`. Passing `'add'` returns that `Operation` object. The tensors it outputs are named `add:<index>`.
One safe choice is `add:0`, the first output. It exists for any operation that has at least one output, such as `add`.
In [1]:
import tensorflow as tf
# Single dtype shared by every placeholder in the graphs built below.
INT = tf.as_dtype("int32")
In [2]:
def graph_add():
    """Build a standalone graph that computes ``c = a + b``.

    The graph holds two scalar ``INT`` placeholders named ``a`` and ``b``
    and an add operation named ``c``. Callers that import it with
    ``tf.import_graph_def`` can map their own tensors onto ``'a'``/``'b'``
    via ``input_map`` and fetch the sum as the tensor ``'c:0'``.

    Returns:
        The ``GraphDef`` protobuf describing the new graph.
    """
    g = tf.Graph()
    # Build inside a fresh graph so nothing leaks into the default graph.
    with g.as_default():
        a = tf.placeholder(INT, [], name='a')
        b = tf.placeholder(INT, [], name='b')
        # Only the op's *name* matters to importers; no Python handle needed.
        tf.add(a, b, name='c')
    return g.as_graph_def()
def graph_pow(exponent=2):
    """Build a standalone graph that computes ``e = d ** exponent``.

    The graph holds one scalar ``INT`` placeholder named ``d`` and a pow
    operation named ``e``. Importers map their input onto ``'d'`` via
    ``input_map`` and fetch the result as the tensor ``'e:0'``.

    Args:
        exponent: Power to raise ``d`` to. Defaults to 2, which matches the
            original squaring behaviour. It is passed straight to ``tf.pow``,
            so it should be compatible with ``INT``.

    Returns:
        The ``GraphDef`` protobuf describing the new graph.
    """
    g = tf.Graph()
    # Build inside a fresh graph so nothing leaks into the default graph.
    with g.as_default():
        d = tf.placeholder(INT, [], name='d')
        tf.pow(d, exponent, name='e')
    return g.as_graph_def()
In [3]:
# NOTE: tf.placeholder, tf.Session and tf.reset_default_graph are TF 1.x
# graph-mode APIs. Under TF 2.x, use tf.compat.v1.* with eager execution
# disabled.
tf.reset_default_graph()

# Inputs of the main graph.
a = tf.placeholder(INT, [], name='a')
b = tf.placeholder(INT, [], name='b')

# Wire a, b into graph_add's placeholders. The output is g1_c.
# Request the tensor 'c:0' rather than the operation 'c', so the result
# can feed the next graph and be fetched as a value by sess.run.
[g1_c] = tf.import_graph_def(
    graph_add(),
    input_map={'a': a, 'b': b},
    return_elements=['c:0'])

# Wire g1_c into graph_pow's placeholder 'd'. The output is g2_e.
[g2_e] = tf.import_graph_def(
    graph_pow(),
    input_map={'d': g1_c},
    return_elements=['e:0'])

# Evaluate both imported outputs in a single run.
with tf.Session() as sess:
    c, e = sess.run([g1_c, g2_e], feed_dict={a: 10, b: 20})
    print('a + b =', c)
    print('(a + b)^2 =', e)