相关动态
Pytorch LSTM实现中文单词预测(附完整训练代码)
2024-11-11 00:18

目录

Pytorch LSTM实现中文单词预测(附完整训练代码)

Pytorch LSTM实现中文单词预测(词语预测 附完整训练代码)

1、项目介绍

2、中文单词预测方法(N-Gram 模型

3、训练词嵌入word2vec(可选

4、文本预处理

(1)句子分词处理:jieba中文分词

(2)特殊字符处理

(3)文本数据增强

(4)样本均衡(重点

5、训练过程 

(1)项目框架说明

(2)准备文本数据

(3)配置文件:config.yaml

(4)开始训练

(6)一些优化建议

6. 模型测试效果

7.项目源码下载


本文将分享一个NLP项目实例,实现一个类似于中文输入法中联想的功能;项目利用深度学习框架Pytorch,构建一个LSTM(也支持NGram,TextCNN,LSTM,BiLSTM等)模型,实现一个简易的中文单词预测(词语预测)功能,该功能可以根据用户输入的中文语句,自动预测(补充)词语;基于该项目训练的中文单词预测(词语预测)模型,在自定义的数据集上Top-1准确率最高可以达到91%左右,Top-5准确率最高可以达到97%左右。

模型context_sizeembedding_dimTop-1准确率Top-3准确率Top-5准确率NGram81280.86300.91800.9357TextCNN81280.90650.96210.9730LSTM81280.9088

0.9535

0.9667BiLSTM81280.91000.95750.9673

如果,你想学习NLP中文文本分类,可参考另一篇博文《Pytorch TextCNN实现中文文本分类(附完整训练代码)》

【尊重原则,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/128582675


首先简单介绍一下 N-Gram 模型的原理:对于一英文句话,单词的排列顺序是非常重要的,所以我们能否由前面的几个词来预测后面的几个单词呢,比如 'I lived in France for 10 years, I can speak _' 这句话中,我们能够预测出最后一个词是 French。

对于一句话T,其由w1,w2....wn这n个词构成,可以得到下面的公式

但是这样的一个模型参数过大,预测一个词需要前面所有的词作为条件来计算概率。我们可以再简化一下这个模型,比如对于一个词,并不需要前面所有的词作为条件概率,也就是说一个词可以只与其前面的几个词有关,这就是马尔科夫假设

对于这里的条件概率,传统的方法是统计语料中每个词出现的频率,根据贝叶斯定理来估计这个条件概率,这里我们就可以用词嵌入对其进行代替,然后使用 RNN 进行条件概率的计算,然后最大化这个条件概率不仅修改词嵌入,同时能够使得模型可以依据计算的条件概率对其中的一个单词进行预测。

类似的,对于中文,我们也可以这样进行处理,比如一个简单的语句:【我是一名中国人】,我们希望如果我们输入【我是一名】,模型输出结果是【中国人】;只不过中文不像英文那样有明显空格作为单词分隔符,中文语句需要我们自己按照一定规则进行字词分割,这个工具可以使用jieba中文分词工具。

下面定义一个简单的NGram模型

 

上面我们定义了一个NGram模型,其中参数 context_size 表示句子最大长度,表示我们希望由前面几个单词来预测这个单词,这里context_size=8个,表示由8单词(不足8个可以填充)预测1个单词;embedding_dim 表示词嵌入的维度,即词向量特征长度。num_embeddings是词典的大小,即我们字库的个数大小,由于输出预测的类别数等于词典的字词数的大小,所num_embeddings = num_classes

NGram模型比较简单如果去除Embedding层,其实就是一个由2层全连接层构成的模型;在实际应用中,效果比较差的,只能作为介绍使用;后续项目,将以LSTM模型为例,进行训练和测试,LSTM模型定义如下,其中我增加LayerNorm用于归一化数据,其作用类似于CNN中的BatchNorm。

 

不管是CNN还是RNN模型,都是无法直接处理字符类别的单词,因此我们需要对单词进行编码,即通过某种方法把单词变成数字形式的向量才能作为模型的输入。把单词映射到向量空间中的一个向量的做法称为词嵌入(word embedding,对应的向量称为词向量(word vector

上面的NGram模型代码中,定义了一个可学习的embedding层,即词嵌入word2vec,其作用就是将word序号ID转换为vector;当然你也可以通过gensim训练自己的word2vec模型,然后在数据处理中先将文本转换为词向量,这样NGram就没有必要添加embedding层了。


接下来,我们需要将句子按照context_size+1的长度进行逐个截取,前context_size个单词是模型输入数据,最后一个是预测结果,这样就构建了我们的训练集,核心代码如下

 

对于中文文本数据预处理,主要有两部分:句子分词处理(英文文本不需要分词),特殊字符处理

(1)句子分词处理:jieba中文分词

本博客使用jieba工具进行中文分词,工具比较简单,就不单独说明了,安装方法

 

(2)特殊字符处理

jieba分词后,会出现很多特殊字符,需要进一步做一些的处理

  • 一些换行符,空格等特殊字符,以及一些标点符号,。?《》)等,这些特殊的字符称为stop_words,需要剔除
  • 一些英文字母大小需要转换统一为小写
  • 一些繁体字统一转换为简体字等
  • 一些专有名词,比如地名,人名这些,分词时需要整体切词:jieba.load_userdict(file)

(3)文本数据增强

在计算机视觉图像识别任务中,图像数据增强主要有:裁剪、翻转、旋转、⾊彩变换等⽅式,其目的增加数据的多样性,提高模型的泛化能力。但是NLP任务中的数据是离散的,无法像操作图片一样连续的方式操作文字,这导致我们⽆法对输⼊数据进⾏直接简单地转换,换掉⼀个词就有可能改变整个句⼦的含义。

常用的NLP文本数据增强方法主要有

  • 随机截取: 随机截取文本一个片段
  • 同义词替换(SR: Synonyms Replace:不考虑stopwords,在句⼦中随机抽取n个词,然后从同义词词典中随机抽取同义词,并进⾏替换。
  • 随机插⼊(RI: Randomly Insert):不考虑stopwords,随机抽取⼀个词,然后在该词的同义词集合中随机选择⼀个,插⼊原句⼦中的随机位置。
  • 随机交换(RS: Randomly Swap):句⼦中,随机选择两个词,位置交换。
  • 随机删除(RD: Randomly Delete):句⼦中的每个词,以概率p随机删除

(4)样本均衡(重点

有一些常用词汇,由于其出现的频率很高,导致模型预测的结果,会偏向于预测高频率出现的词汇;比如【的】字,在中文语句中,出现的频率特别大,导致模型预测的时候,输出结果经常被预测为【的】,显示这是不符合实际情况;一种行之有效的解决方法,是对数据进行均衡采用,即高频词汇应该降低采样次数,而低频词应该增加其采用次数。

项目已经实现样本均衡算法,config.yaml配置文件中,只需要设置resample: True即可开启样本均衡训练

项目已经实现随机截取,随机插⼊,随机删除等几种文本数据增强方式

 

(1)项目框架说明

 

 项目依赖的python包,请使用pip安装对应版本

 

(2)准备文本数据

首先,我们需要收集中文文本数据集,由于我们是做单词预测算法,要求训练数据尽可能干净;考虑到我们的模型比较简单,无需像BERT那样海量数据作为简单的Demo,项目从百度文库中收集了一些中文造句的常用句子,大概5千字的数据量吧

 然后根据自己的保存的数据路径,修改配置文件数据路径:configs/config.yaml (项目的文本数据放在data/text/data"中,可自行增加补充数据;考虑到,单词预测是一种比较模糊预测的任务,因此,项目没有严格区分训练集和测试集,而是将训练数据和测试集都使用同一数据集。

 

(3)配置文件:config.yaml

 
  • 目标支持模型主要有NGram,TextCNN,LSTM,BiLSTM,详见模型等 ,其他模型可以自定义添加
  • 训练参数可以通过config.yaml配置文件
参数类型参考值说明train_datastr, list-训练数据文件,可支持多个文件test_datastr, list-测试数据文件,可支持多个文件vocab_filestr-
字典文件(会根据训练数据集自动生成)
class_namestr-类别文件data_typestr-加载数据DataLoader方法resampleboolTrue是否进行重采样work_dirstrwork_space训练输出工作空间net_typestrLSTM骨干网络,支持:NGram,TextCNN,LSTM,BiLSTM等模型context_sizeint128句子长度topklist[1,3,5]计算topK的准确率batch_sizeint32批训练大小lrfloat0.1初始学习率大小optim_typestrSGD优化器,{SGD,Adam}loss_typestrCELoss损失函数schedulerstrmulti-step学习率调整策略,{multi-step,cosine}milestoneslist[30,80,100]降低学习率的节点,仅仅scheduler=multi-step有效momentumfloat0.9SGD动量因子num_epochsint120循环训练的次数num_workersint12DataLoader开启线程数weight_decayfloat5e-4权重衰减系数gpu_idlist[ 0 ]指定训练的GPU卡号,可指定多个log_freqint20显示LOG信息的频率finetunestrmodel.pthfinetune的模型

(4)开始训练

整套训练代码非常简单操作,用户只需要将文本数据放在项目"data/text/data"目录下(也可以自定义数据路径,并填写好对应的数据路径,即可开始训练了。

  • 如果你想验证项目可不可以训练,请运行下面命令开始训练;项目自带了小批量的文本数据,方便测试项目代码;对于简单的样本数据集,可以获得95%左右的预测准确率
 
  •  如果你想训练自己的数据,请准备好文本数据,并放在data/text/data中,文本只支持TXT格式,不支持PDF和word文档格式
  • 配置文件configs/config.yaml的参数net_type,用于选择骨干网络,可以填写NGram, TextCNN, LSTM, BiLSTM等,后面模型以LSTM模型为准

以下是训练代码

 

(5)可视化训练过程

训练过程可视化工具是使用Tensorboard,使用方法
 

可视化效果 

​​​​​​​​​​​​​​​​​​​

(6)一些优化建议

训练完成后,可查看其Top-1,Top-3和Top-5准确率其中NGram的Top-1准确率约0.8630,而TextCNN的准确率约0.9065,LSTM的准确率约0.9088,BiLSTM的准确率最高可以达到0.9100

模型context_sizeembedding_dimTop-1准确率Top-3准确率Top-5准确率NGram81280.86300.91800.9357TextCNN81280.90650.96210.9730LSTM81280.9088

0.9535

0.9667BiLSTM81280.91000.95750.9673
  1. 数据整合:建议对数据进行去燥,删除一些语句不通的文本
  2. 由于数据集比较小,有很多中文字词是不支持,建议增大数据集进行训练
  3. 增加LSTM参数量:比如增大LSTM的个数
  4. 增加pretrained模型:项目构建LSTM模型,随机初始化了一个可学习的二维矩阵:Embedding,该Embedding模型没有增加pretrained的,若能加入pretrained,其准确率会好很多。
  5. 文本数据增强:如同义词替换,文本随机插入,随机删除等处理,增强模型泛化能力
  6. 样本均衡:数据不均衡,部分类目数据太少; 建议进行样本均衡处理,减少长尾问题的影响
  7. 超参调优: 比如学习率调整策略,优化器(SGD,Adam等
  8. 损失函数: 目前训练代码已经支持:交叉熵,LabelSmoothing,可以尝试FocalLoss等损失函数

predictor.py文件用于模型推理和测试脚本,填写好配置文件,模型文件以及测试文本即可运行测试了

 

或者在项目根目录终端运行命令(表示换行符

 

使用方法: 

  1. 【使用说明】:输入任意文本,用[空格]或[/]表示需要预测的字词;输入[e]退出程序
  2. 【输入例子】:美丽豪华的/获得优异的
  3. 【结果说明】:输出括号内表示预测结果

运行测试结果:  

  • 输入"我们家住的楼上有许多只/每天晚上你都能看到 一双圆溜溜的 拖着一条长长的",返回的预测结果
  • 输入"美丽豪华的 获得优异的,返回的预测结果


整套项目源码下载:Pytorch LSTM实现中文单词预测(词语预测)

整套项目源码内容包含

  • 提供中文文本数据集,用于模型训练:数据主要从百度文库中收集了一些中文造句的常用句子,大概5千字的数据量
  • 提供Pytorch版本的中文单词预测模型训练工具:train.py,支持NGram,TextCNN, LSTM, BiLSTM等模型训练和测试
  • 提供中文单词预测测试脚本:predictor.py
  • 项目已经实现样本均衡算法,config.yaml配置文件中,只需要设置resample: True即可开启样本均衡训练
  • 简单配置,一键开启训练自己的中文单词预测模型
    以上就是本篇文章【Pytorch LSTM实现中文单词预测(附完整训练代码)】的全部内容了,欢迎阅览 ! 文章地址:http://sjzytwl.xhstdz.com/quote/74949.html 
     栏目首页      相关文章      动态      同类文章      热门文章      网站地图      返回首页 物流园资讯移动站 http://sjzytwl.xhstdz.com/mobile/ , 查看更多   
发表评论
0评