【CS20-TF4DL】09 RNN 与语言模型

上节课我们了解了图像的风格迁移,接下来我们来看看另一个非常出名的网络结构 —— RNN,以及语言模型这一应用场景。

更新历史

  • 2019.08.18: 完成初稿

RNN 简介

我们前面接触到的全连接和 CNN 可以很好地完成线性回归和 logistic 回归,或者去识别特定的对象。但是,在处理我们日常用的自然语言时,前面两种模型就有点吃力了。为了要能够捕捉序列化数据中的信息,RNN 应运而生(最早由 Elman 在 1990 年 提出 )。RNN 与之前的模型最大的不同就是打破了只往一个方向传递计算结果的套路,允许神经元的输出再次作为自己的输入,对比如下图所示:

这样的网络结构单个比较容易画出来,一旦序列长度增加,指向自己的箭头不好处理,所以我们一般会沿着时间方向对模型进行展开:

RNN 有啥优点呢:

  1. 利用了数据的序列信息
  2. 降低了参数个数
  3. 给 NLP 提供了很大支持

BPTT

BPTT 是 Back-propagation through Time 的缩写,也就是用来为 RNN 设计的反向传播技术,这个和之前的反向传播的不同之处在于:

  1. 对于 RNN 来说,不同 timesteps 共享相同的参数,我们根据每个训练样本/批次的所有 timesteps 的梯度之和来更新这些参数(在 CNN 和全连接中,每一层都有独立的参数)
  2. 对于 RNN 来说,timesteps 会根据序列长度的变化而变化(在 CNN 和全连接中,网络的层数是确定的)

这里我们关注下第二点,如果我们的序列非常长(比如是一片文章,有 1000 个词),那么整个反向传播的计算量会非常大,与此同时可能造成梯度爆炸或梯度消失。更多关于 BPTT 的介绍可以参考 这里

为了避免对所有的 timesteps 进行参数更新,我们常常会限制 timesteps 的个数,称为 truncated BPTT。这个方法一方面加速了计算,但另一方面降低了 RNN 的学习能力(相当于只能往回看指定的 timesteps,而不是全部的信息)。

在 Tensorflow 中,我们需要在开始训练前就指定好具体的 timesteps 数量,也就是我们需要事先确定序列的长度。因此,要么把长度分成桶然后分桶训练,要么设定一个长度进行截取或填充。

Gated Recurrent Unit

在实践中,RNN 在捕捉长效依赖上非常差劲(大概就是句子的最后和开头相关,因为 timesteps 的限制,RNN 无法捕捉到这样的信息)。为了解决长效依赖的问题,人们开始使用 LSTM,有趣的是,LSTM 其实在上世纪 90 年代就已提出,只是现在因为计算力的增长,才重新开始受人关注。

关于 LSTM 结构的图很多,下面这张是比较清晰的( 来源

一个 LSTM Cell 包含 4 个门,一般来说记为 $i,o,f,c^\sim$,分别代表 输入门(input)、输出门(output)、遗忘门(forget)和候选门(candidate/new memory)。具体的公式这里就不详细列出,网上有很多相关资源。简单来说,我们可以认为这些门是用来控制信息流通的,每个门都有相同的维度:

  • 输入门:当前输入有多少可以通过
  • 遗忘门:上一个 state 有多少会进入下一个 state
  • 输出门:当前的输出有多少会传递到下一个 state
  • 候选门:根据上一 state 和当前输入计算出当前 state
  • 最终输出:把前面所有的组合起来,作为给下一个 state 的输出

LSTM 的目标是捕捉长效依赖,GRU 则是 LSTM 的一个简化版本,虽然从理论上来说 GRU 可以大幅降低计算量,但是实际使用时相对 LSTM 并没有特别大的提升。GRU 和 LSTM 的对比如下:

在 Tensorflow 我们更倾向于使用 GRU,因为 LSTM 实在太笨重了,GRU 的输出只有一个 state,而 LSTM 有俩,处理起来就要多花一点时间。

语言建模

语言建模常用的方法有:

  • N-gram
    • 最传统的方式,基于前面的几个词,预测后一个词
    • 需要非常大的词表,且无法生成词表之外的词
    • 内存占用巨大
  • Character-level
    • 优势:词表很小,不需要词嵌入,训练很快
    • 劣势:生成的句子不流畅(很多词汇可能是毫无意义的)
  • Hybrid 混合
    • 默认的单词级别,当遇到无法识别的 token 时,转为 character-level
  • Subword-Level
    • 输入和输出都是单词的子集
    • 保留 W 个最常出现的词,保留 S 个最常出现的符号,其他的切分为 character
    • 效果在英文上会比 word-level 和 character-level 好(但是在中文上,不能这么搞)
    • 比如 new company dreamworks interactive -> dre+ am+ wo+ rks: in+ te+ ra+ cti+ ve:

RNN 的一个使用场景是给定一系列的单词,然后预测后面可能出现的词汇,除了一个单词一个单词预测,也可以是一个字母一个字母预测。一般来说,在构造语言模型时,输入是一系列的词(或字母),输出是下一个词(或字母)的概率分布。关于 char-RNN 的更多内容可以参考 这里

接下来我们会构建俩 char-RNN 模型,一个基于特朗普的推特,另一个基于 arvix 摘要。arvix 摘要数据集包含 20466 个摘要,每个的长度都在 500-2000 个字符。特朗普的推特数据集包含 2018.2.15 日之前的所有推特,总共 19469 条,每条都小于 140 个字符。这两个数据集都做了简单的预处理:

  1. 替换所有的 URL 为 __HTTP__ (其实应该用更短的,比如 _U_
  2. 每句话添加结束符 _E_

我们生成的特朗普推特大概长这样:

I will be interviewed on @foxandfriends tonight at 10:00 P.M. and the #1 to construct the @WhiteHouse tonight at 10:00 P.M. Enjoy __HTTP__

No matter the truth and the world that the Fake News Media will be a great new book #Trump2016 __HTTP__ __HTTP__

我们生成的 arvix 摘要大概长这样:

“Deep learning neural network architectures can be used to best developing a new architectures contros of the training and max model parametrinal Networks (RNNs) outperform deep learning algorithm is easy to out unclears and can be used to train samples on the state-of-the-art RNN more effective Lorred can be used to best developing a new architectures contros of the training and max model and state-of-the-art deep learning algorithms to a similar pooling relevants. The space of a parameter to optimized hierarchy the state-of-the-art deep learning algorithms to a simple analytical pooling relevants. The space of algorithm is easy to outions of the network are allowed at training and many dectional representations are allow develop a groppose a network by a simple model interact that training algorithms to be the activities to maximul setting, …”

RNN 部分的构建代码,在 Tensorflow 2.0 版本有比较大的变化,原来的代码设计比较杂乱和繁琐,建议主要参考 Keras 来看

具体的代码请参考 19_char_rnn.py ,我们在执行过程中可以观察到学习过程中的不同阶段生成的结果,慢慢会从原来的乱码(需要训练较多轮),变成稍有意义的文字

# 第 56 轮
Iter 56. Loss 10518.56640625. Time 0.8621890544891357
	Hillary_$9esu(b%rI
	I@.Ji➡4aSOo
	RwSsP'Pgsp)                                                                                                                                                                                         t
	T#D8'YQ?cdn:chart_with_upwards_trend:KK

# 第 506 轮
Iter 506. Loss 8116.001953125. Time 0.44046592712402344
	HillaryBin tor and on the tor on the tor and and an the the the tor the the the the the tor in the the tor the tor the tor the tor an and the the tor and the tor the the the tor in and the tor the the the to
	Ir on the the tor the the tor and on the the the the tor and on the the tor the the the the tor be the the tor the the the tor he the tor and and the the tor the tor the tor for the the the the the the
	R%f an and the the se the the the tor the tor the tor and and on the the tor the tor and the the tor the the tor @Tremang the tor the the the tor the tor the the the tor the tor the tor the the tor in
	
# 第 1179 轮
Iter 1179. Loss 7250.015625. Time 0.9244301319122314
	Hillary and sour the has and he beat on the the pront the A and Bare us the the beat and the mere and the the for the for the pront the will the beat on the will and the Leat The U Ant and the the reating th
	I's the beat for the and and the pront be the and the the pront the the the will be the the got the pront and in the pront in the and on the and on the will the xand will the and the preat the ,alling
	Ring the beat Gor And and the has and in the Preat me the Wand will be the pront the preat the the and the beat and the preat the 0 will be the Nes and Fare sour the and the Manding and reat the in the
	
# 第 2043 轮 
Iter 2043. Loss 6292.2451171875. Time 0.699113130569458
	Hillary the stor the doner the U.S. will be the donger the United in the +recane the deal the on the has the American in the in the pronite the Country and the sour the beal the proned the America be the rea
	I will be a preat and the beal the will be the really really to a could the prople the proned the sour the will be the erenting the pronite the Sear of the Make and the deal will be the @BarackObama an

我们可以看到,为了得到一个比较好的结果,是需要训练非常多轮的!

我来评几句
登录后评论

已发表评论数()

相关站点

+订阅
热门文章