共计 2761 个字符,预计需要花费 7 分钟才能阅读完成。
图是静态的,无论是加减乘除,只是定义了各种计算关系,不会有实际的任何运算
图的组成
- 输入节点
- 模型参数
- OP
默认计算图
在TensorFlow中会自动维护一个默认的一个计算图,所以我们能够直接定义的tensor或者运算都会被转换为计算图上一个节点。
v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
with tf.Session() as sess:
# 判断v1所在的graph是否是默认的graph
print(v1.graph is tf.get_default_graph())
print(sess.run(add))
# 输出 True
# 输出 [[3. 3.]]
我们可以通过tf.get_default_graph()
来获取当前节点所在的计算图。我们通过判断v1
tensor所在的计算图和默认的计算图进行比较,发现v1
的值处于默认的计算图上,由此也验证了:TensorFlow会自动维护一个默认的计算图,并将我们的节点添加到默认的计算图上。
我们可以看到默认的计算图上有三个节点,分别是v1
和v1
节点,它们共同组成了add
节点。
创建Graph
我们可以通过tf.Graph()新增计算图,并通过as_default()将变量和计算添加在当前的计算图中,最后通过Session的graph=计算图来计算指定的计算图。
# 新增计算图
new_graph = tf.Graph()
with new_graph.as_default():
# 在新增的计算图中进行计算
v1 = tf.constant(value=3, name='v1', shape=(1, 2), dtype=tf.float32)
v2 = tf.constant(value=4, name='v2', shape=(1, 2), dtype=tf.float32)
add = v1 + v2
# 通过graph=new_graph指定Session所在的计算图
with tf.Session(graph=new_graph) as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(add))
# 在默认计算图中进行计算
v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 通过graph=tf.get_default_graph()指定Session所在默认的计算图
with tf.Session(graph=tf.get_default_graph()) as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(add))
# 输出:[[7. 7.]]
# 输出:[[3. 3.]]
带有PlaceHolder的计算图
import tensorflow as tf
a=tf.placeholder(dtype=tf.float32,shape=[1])
b=tf.placeholder(dtype=tf.float32,shape=[1])
c=a+b
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(c,feed_dict={a:[2.1],b:[3.2]}))
[5.3]
多个图之间互不相干
import tensorflow as tf
g1=tf.Graph()
with g1.as_default():
v=tf.get_variable("v",[1],initializer=tf.zeros_initializer(dtype=tf.float32))
g2=tf.Graph()
with g2.as_default():
v=tf.get_variable("v",[1],initializer=tf.ones_initializer(dtype=tf.float32))
with tf.Session(graph=g1) as sess:
tf.initialize_all_variables().run()
with tf.variable_scope("",reuse=True): # 当reuse=True时,tf.get_variable只能获取指定命名空间内的已创建的变量
print(sess.run(tf.get_variable("v")))
with tf.Session(graph=g2) as sess:
tf.initialize_all_variables().run()
with tf.variable_scope("",reuse=True): # 当reuse=True时,tf.get_variable只能获取指定命名空间内的已创建的变量
print(sess.run(tf.get_variable("v")))
#输出:[0.] [1.]
跟图相关的一些操作
1、根据 tensor name 来获取对应的tensor
对应的方法 get_tensor_by_name
import tensorflow as tf
a=tf.placeholder(dtype=tf.float32,shape=[1],name='v1')
b=tf.placeholder(dtype=tf.float32,shape=[1],name='v2')
c=a+b
d=tf.add(a,b,name='add')
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(c,feed_dict={a:[2.1],b:[3.2]}))
test1=tf.get_default_graph().get_tensor_by_name('add:0')
print(sess.run(test1,feed_dict={a:[1.0],b:[2.0]}))
[5.3]
[3.]
2、获取 operation 信息
对应的方法 get_operation_by_name
Q : with new_graph.as_default(): 在这里面运行tf.get_default_graph()获取的是什么图?
正文完
请博主喝杯咖啡吧!