发新帖

在不同的graph分别建train和infer模型的问题

[复制链接]
842 3

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
本帖最后由 tfmonster 于 2018-4-20 15:14 编辑

        使用BiLSTM+CRF完成序列标注任务,想按照nmt的代码风格把train模型和infer模型建在不同的图,主要是方便train模型训练的时候使用iterator喂数据,而在infer时直接placeholder方便后续封装成api直接调用。代码大致结构如下:

  1. class Net:
  2.         def __init__(self):
  3.                 self.train_graph = tf.Graph()
  4.                 self.infer_graph = tf.Graph()
  5.         
  6.         def build_model(mode, iterator=None, source=None, label=None)
  7.                 if mode == ModeKeys.TRAIN:
  8.                         (src,tgt) = iterator
  9.                 else:
  10.                         src,tgt = source, label
  11.                 '''
  12.                         build_rnn_model_code
  13.                 '''
  14.                 return

  15.         def train(self):
  16.                 with self.train_graph.as_default():
  17.                         iterator = get_iterator
  18.                         trian_model = build_model(mode=ModeKeys.TRAIN, iterator)
  19.                 with tf.Session(graph=self.train_graph) as sess:
  20.                         '''
  21.                                 train model
  22.                         '''
  23.         def infer(self):
  24.                 with self.infer_graph.as_default():
  25.                         source_placeholder = tf.placeholder(tf.int32, shape=([]))
  26.                         lable_placeholder = tf.placeholder(tf.int32, shape=([]))
  27.                         infer_model = build_model(mode=ModeKeys.INFER, source=source_placeholder , label=lable_placeholder )
  28.                 with tf.Session(graph=self.infer_graph) as sess:
  29.                         '''
  30.                                 infer
  31.                         '''

  32. def main():
  33.         net = NET()
  34.         net.train()
复制代码


        运行的时候总是报错must be from the same graph as Tensor, 问题是在iterator = get_iterator这一步, get_iterator就是用的nmt的iterator写法,我把get_iterator全部放在train里面依旧报这个错,求教该怎么处理。
我知道答案 回答被采纳将会获得10 金币 + 15 金币 已有3人回答
本楼点评(0) 收起

精彩评论3

舟3332  TF芽芽  发表于 2018-4-20 19:46:35 | 显示全部楼层
nmt 的代码风格方便发一个链接吗?

我想问一个相关的问题。你建立了两个 graph 那么中间的参数是怎么传递的呢?
本楼点评(0) 收起
tfmonster  TF荚荚  发表于 2018-4-20 20:04:08 | 显示全部楼层
舟3332 发表于 2018-4-20 19:46
nmt 的代码风格方便发一个链接吗?

我想问一个相关的问题。你建立了两个 graph 那么中间的参数是怎么传递 ...

https://github.com/tensorflow/nmt, 两个图里面构建的模型完全一样,用saver.restore恢复参数,具体见项目READEME的Tips & Tricks章节
本楼点评(0) 收起
tf剑客  TF荚荚  发表于 2018-4-21 14:25:55 | 显示全部楼层
  1. if mode == ModeKeys.TRAIN:
  2.               (src,tgt) = iterator
复制代码

这里也要把它放在train graph的上下文里吧
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表