位置: IT常识 - 正文

基于 transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读(基于transformers的nlp入门 pdf)

编辑:rootadmin
基于 transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读 一、前言

推荐整理分享基于 transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读(基于transformers的nlp入门 pdf),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:基于FOB镜检查的吸入性AIS分级系统,基于专业性的家校双向互动,需要家长的学校教育参与,基于核心素养下的大单元教学设计,基于transformers的nlp入门 pdf,基于网页的客服系统,基于核心素养下的大单元教学设计,基于transformers的nlp入门 pdf,基于是什么意思,内容如对您有帮助,希望把文章链接给更多的朋友!

最近在做文本生成,用到huggingface transformers库的文本生成 generate() 函数,是 GenerationMixin 类的实现(class transformers.generation_utils.GenerationMixin),是自回归文本生成预训练模型相关参数的集大成者。因此本文解读一下这些参数的含义以及常用的 Greedy Search、Beam Search、Sampling(Temperature、Top-k、Top-p)等各个算法的原理。

这个类对外提供的方法是 generate(),通过调参能完成以下事情:

greedy decoding:当 num_beams=1 而且 do_sample=False 时,调用 greedy_search()方法,每个step生成条件概率最高的词,因此生成单条文本。multinomial sampling:当 num_beams=1 且 do_sample=True 时,调用 sample() 方法,对词表做一个采样,而不是选条件概率最高的词,增加多样性。beam-search decoding:当 num_beams>1 且 do_sample=False 时,调用 beam_search() 方法,做一个 num_beams 的柱搜索,每次都是贪婪选择top N个柱。beam-search multinomial sampling:当 num_beams>1 且 do_sample=True 时,调用 beam_sample() 方法,相当于每次不再是贪婪选择top N个柱,而是加了一些采样。diverse beam-search decoding:当 num_beams>1 且 num_beam_groups>1 时,调用 group_beam_search() 方法。constrained beam-search decoding:当 constraints!=None 或者 force_words_ids!=None,实现可控文本生成。二、各输入参数含义

接下来分别看看各个输入参数(源代码):

我觉得对文本生成质量最有用的几个参数有:max_length、min_length、do_sample、top_k、top_p、repetition_penalty。接下来选择性地记录各个参数的含义。

inputs (torch.Tensor of varying shape depending on the modality, optional) — The sequence used as a prompt for the generation or as model inputs to the encoder. If None the method initializes it with bos_token_id and a batch size of 1. For decoder-only models inputs should of in the format of input_ids. For encoder-decoder models inputs can represent any of input_ids, input_values, input_features, or pixel_values.

inputs:输入prompt。如果为空,则用batch size为1的 bos_token_id 初始化。对于只有decoder的模型(GPT系列),输入需要是 input_ids;对于 encoder-decoder模型(BART、T5等),输入更多样化。

max_length (int, optional, defaults to model.config.max_length) — The maximum length of the sequence to be generated.

max_length:生成序列的最大长度。

min_length (int, optional, defaults to 10) — The minimum length of the sequence to be generated.

min_length:生成序列的最短长度,默认是10。

do_sample (bool, optional, defaults to False) — Whether or not to use sampling ; use greedy decoding otherwise.

do_sample:是否开启采样,默认是 False,即贪婪找最大条件概率的词。

early_stopping (bool, optional, defaults to False) — Whether to stop the beam search when at least num_beams sentences are finished per batch or not.

early_stopping:是否在至少生成 num_beams 个句子后停止 beam search,默认是False。

num_beams (int, optional, defaults to 1) — Number of beams for beam search. 1 means no beam search.

num_beams:默认是1,也就是不进行 beam search。

temperature (float, optional, defaults to 1.0) — The value used to module the next token probabilities.

默认是1.0,温度越低(小于1),softmax输出的贫富差距越大;温度越高,softmax差距越小。

top_k (int, optional, defaults to 50) — The number of highest probability vocabulary tokens to keep for top-k-filtering.

top_k:top-k-filtering 算法保留多少个 最高概率的词 作为候选,默认50。详见下文。

top_p (float, optional, defaults to 1.0) — If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.

top_p:已知生成各个词的总概率是1(即默认是1.0)如果top_p小于1,则从高到低累加直到top_p,取这前N个词作为候选。

typical_p (float, optional, defaults to 1.0) — The amount of probability mass from the original distribution to be considered in typical decoding. If set to 1.0 it takes no effect. See this paper for more details.

typical_p:典型采样(不知道能否这样翻译),默认值 1.0 此参数无效,主要思想:不总是从分布高概率区域中选词,而是从信息含量接近预期值typical_p(即接近模型的条件熵)的单词集合中采样。 论文:Typical Decoding for Natural Language Generation

repetition_penalty (float, optional, defaults to 1.0) — The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.

repetition_penalty:默认是1.0,重复词惩罚。 论文:CTRL: A CONDITIONAL TRANSFORMER LANGUAGE MODEL FOR CONTROLLABLE GENERATION

pad_token_id (int, optional) — The id of the padding token. bos_token_id (int, optional) — The id of the beginning-of-sequence token. eos_token_id (int, optional) — The id of the end-of-sequence token.

pad_token_id / bos_token_id / eos_token_id:填充词<PAD>、起始附<s>、结束符</s> 的id。

length_penalty (float, optional, defaults to 1.0) — Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. 0.0 means no penalty. Set to values < 0.0 in order to encourage the model to generate longer sequences, to a value > 0.0 in order to encourage the model to produce shorter sequences.

length_penalty:长度惩罚,默认是1.0。

length_penalty=1.0:beam search分数会受到生成序列长度的惩罚length_penalty=0.0:无惩罚length_penalty<0.0:鼓励模型生成长句子length_penalty>0.0:鼓励模型生成短句子

no_repeat_ngram_size (int, optional, defaults to 0) — If set to int > 0, all ngrams of that size can only occur once.

no_repeat_ngram_size:用于控制重复词生成,默认是0,如果大于0,则相应N-gram只出现一次

基于 transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读(基于transformers的nlp入门 pdf)

encoder_no_repeat_ngram_size (int, optional, defaults to 0) — If set to int > 0, all ngrams of that size that occur in the encoder_input_ids cannot occur in the decoder_input_ids.

encoder_no_repeat_ngram_size:也是用于控制重复词生成,默认是0,如果大于0,则encoder_input_ids的N-gram不会出现在 decoder_input_ids里。

bad_words_ids(List[List[int]], optional) — List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids.

bad_words_ids:禁止生成的词id列表,可用 tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids 方法获取ids。

force_words_ids(List[List[int]] or List[List[List[int]]], optional) — List of token ids that must be generated. If given a List[List[int]], this is treated as a simple list of words that must be included, the opposite to bad_words_ids. If given List[List[List[int]]], this triggers a disjunctive constraint, where one can allow different forms of each word.

force_words_ids:跟上面的 bad_words_ids 相反,这个传入必须生成的token id 列表。如果ids格式是 [List[List[int]]],比如 [[1,2],[3,4]],则触发析取约束(Disjunctive Positive Constraint Decoding),大概意思就是可以生成一个单词不同的形式,比如“lonely”、“loneliness”等。 论文:Guided Generation of Cause and Effect

num_return_sequences(int, optional, defaults to 1) — The number of independently computed returned sequences for each element in the batch.

num_return_sequences:每条输入产生多少条输出序列,默认为1。

max_time:多少秒之后停止生成。

attention_mask:默认跟输入 input_ids 的shape一样,0代表mask,1代表不mask,被mask掉的token不参与计算注意力权重。

decoder_start_token_id:encoder-decoder架构的模型有可能解码起始符跟编码器不一样(比如[CLS]、<s>)时可指定一个int值。

num_beam_groups (int, optional, defaults to 1) :beam search的时候为了确保不同beam之间的多样性,可以将这些beam划分成group,详见论文 Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models。

diversity_penalty (float, optional, defaults to 0.0):如果在同一个step中某个beam生成的词和其他beam有相同的,那么就减去这个值作为惩罚,仅在 num_beam_groups 启用时这个值才有效。

prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]], optional):如果提供该函数,就会把beam search每个step限制在允许的token里搜索,否则不做约束。函数有2个输入,分别是 batch_id 和当前步的输入 input_ids,返回一个list,包含下个step允许的token。可用于条件约束生成。详见论文 Autoregressive Entity Retrieval。

output_attentions (bool, optional, defaults to False) :是否返回所有attention层的注意力矩阵值,默认False。

output_hidden_states (bool, optional, defaults to False):是否返回各个层的hidden_states,默认是False。

output_scores (bool, optional, defaults to False):是否返回预测分数。

forced_bos_token_id (int, optional):解码器在生成 decoder_start_token_id 对应token之后指定生成的token id,mBART这种多语言模型会用到,因为这个值一般用来区分target语种。

forced_eos_token_id (int, optional):达到最大长度 max_length 时,强制作为最后生成的token id。

remove_invalid_values (bool, optional):是否删除模型nan(not a number)和inf(正无穷)防止崩溃,但可能会减慢生成速度。

exponential_decay_length_penalty (tuple(int, float), optional):生成一定数量的token之后,施加一个指数增长的长度惩罚,格式为 (start_index, decay_factor),前者表示从开始施加惩罚的索引,后者表示指数衰减因子。

三、函数输出含义

若 return_dict_in_generate=True 或者 config.return_dict_in_generate=True 时返回 ModelOutput 类对象(class transformers.utils.ModelOutput),否则返回 torch.FloatTensor。

四、各解码算法原理简述

本小节主要介绍自回归文本生成的几个最常用的解码方法,包括 Greedy search, Beam search, Top-K sampling 以及 Top-p sampling。自回归生成都是基于以下公式,也就是假设一个单词序列的概率分布等于各单词条件概率乘积。 P(w1:T∣W)=∏t=1TP(wt∣w1:t−1,W) ,with w1:=∅,P(w_{1:T} | W_0 ) = \prod_{t=1}^T P(w_{t} | w_{1: t-1}, W_0) \text{ ,with } w_{1: 0} = \emptyset,P(w1:T​∣W0​)=t=1∏T​P(wt​∣w1:t−1​,W0​) ,with w1:0​=∅,

4.1 Greedy Search

贪婪搜索,每个时间步 ttt 都选概率最高的那个词: wt=argmaxwP(w∣w1:t−1)w_t = argmax_{w}P(w | w_{1:t-1})wt​=argmaxw​P(w∣w1:t−1​) 比如图中,最终生成的序列是 (“The”,“nice”,“woman”)。这种贪婪算法和beam search的共同弊端就是容易生成重复词,小试一下: 此外,贪婪搜索容易忽略掉低概率词后面的高概率词,比如开头那个图里“the dog has”,概率是 0.4*0.9=0.36,比“the nice woman”的 0.5*0.4=0.20 要高,但由于第一轮dog概率比nice低,导致了与图中更优解擦肩而过。beam search 就能解决这个问题。

4.2 Beam Search

beam search每个时间步选择最可能的 Top - num_beams 个词,解决了贪婪搜索擦肩而过的风险。 如图例子,num_beams=2,第一步选了概率最高的序列 the nice(0.5) 和 the dog(0.4),第二步选了概率最高的序列 the dog has(0.4✖️0.9=0.36)和 the nice woman(0.5✖️0.4=0.20)。

注意,beam search虽然比贪婪搜索能找到概率更高的解,但不保证是全局最优解。

小试一下。设置num_beams > 1,early_stopping=True,当指定数量个beam生成结束符就早停。 比刚才好一些了,但还是有重复,可以加上no_repeat_ngram_size=2 禁止模型生成重复的 2-gram。但是需要慎用,因为“喜欢”这个词生成完后面就不能生成了,这就导致“喜欢周杰伦”没了。 此外,还可以通过 num_return_sequences 参数指定返回概率最高的 topN 个序列。 可见生成的这个top5个序列差异不算太大。

关于beam search,有这么三个说法:

如果生成长度是提前可预知,比如摘要、翻译,这种用beam search好;但开放式生成,比如对话,故事生成等,输出长度变化比较大,就不太适合用beam search了。beam search容易重复生成单词。由于通过大量实验才能达到“禁止生成重复n-gram” 和 “允许周期性生成重复n-gram” 的平衡,所以在开放式生成任务上不太好用这种惩罚来控制重复。人类往往说话时不是总选择高概率的词作为下个词,而是经常猝不及防,出其不意,如图对比。所以beam search还是有很大问题的。 4.3 Sampling

采样算法不再拘泥于高概率词,而是根据条件概率分布随机挑选单词,如图,car这种低概率词也有机会被选中作为生成文本。 generate函数中,设置 do_sample=True,并通过 top_k=0 先暂时停用topk采样,来看看实际效果。 可以看到模型有点胡言乱语的感觉…这时候,temperature参数就派上用场了。

4.3.0 Temperature

temperature参数相当于给softmax降降温,让各个词概率差距加大(跟刚才的随机 sample 相比,增加了高概率词的可能性,降低了低概率词的可能性)公式如下: 对比一下:下图是加了 temperature 的。 下图是没加 temperature 的(默认是1.0)。 可以看到:

T越小,趋近于0,概率密度就越集中在高概率词,就更偏向贪婪搜索,更容易产生重复词。T越大,趋近于1,就越趋向原始softmax,随机性就越大。T越大,甚至大于1,采样越随机,概率分布越趋向均匀分布。 小试一下: temperature = 0.7 时: temperature = 0.1 时: 4.3.1 Top-k 采样

Hierarchical Neural Story Generation 提出了 Top-K 采样方法,原理是先找出K个最有可能的单词,然后在这K个单词中计算概率分布,如图(蓝色为各个step的TopK)。GPT2就用了这种采样方法。 可见比之前要好很多,但是存在问题:

top-k 采样的问题就是这个 K 是死的,没法动态调整,这就导致上面例子中,左图 t=1 步骤,概率分布比较平缓,右图 t=2步骤,概率分布比较悬殊上面例子中,给定the后,t=1选的词还算合理,但t=2时,down 和 a 显然都不太适合但也被选到候选集里了

因此,将候选集限制为固定值K个,可能让模型在右图的悬殊分布里生成胡言乱语,也限制了在平缓分布中的一些创造性。所以top-p采样应运而生。

4.3.2 Top-p 采样

Top-p (nucleus) sampling 是 Ari Holtzman et al. (2019) 提出的算法。他是从使得累计概率超过 p 的最小候选集里选择单词,然后算这些单词的概率分布。这样候选单词集的大小就不跟topK似的一成不变了,会随下一个单词的概率分布动态增加和减少。 比如设置 p = 0.92,给定the后,t=1时前面9个词加起来概率为0.94,刚刚超过了0.92,于是前9个词成了候选词;t=2时前面3个词概率加起来已经达到了0.97。

也就是说,当下个单词不太可预测时,那候选就多一些;如果下个单词模型打眼一看就知道是哪些,那候选就少一些。

top_p 是 0-1 之间的值,值越接近1效果越好。 当 p 设置的比较大时,top-p 采样出来的候选词可能巨多,所以可以跟 top-k 结合起来用,避免那些 top-p 选中的概率很低的词,如图设置 top_p 和 top_k。

官方文档:https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate

参考资料: https://zhuanlan.zhihu.com/p/115076102 https://zhuanlan.zhihu.com/p/453286395 https://huggingface.co/blog/how-to-generate

本文链接地址:https://www.jiuchutong.com/zhishi/297532.html 转载请保留说明!

上一篇:CSS 获取当前可视屏幕高度--使用calc()方法动态计算宽度或者高度(css获取id)

下一篇:clone下来的vue项目出现“An unknown git error occurred”,vue全局挂载axios及配置全局请求和响应拦截,uni-app的全局请求和响应拦截,对请求方法的封装(vue clonedeep)

  • IPhone12pro激光雷达有什么用(iPhone12pro激光雷达怎么打开)

    IPhone12pro激光雷达有什么用(iPhone12pro激光雷达怎么打开)

  • 拼多多拼小圈怎么关闭(拼多多拼小圈怎么关闭自己的动态)

    拼多多拼小圈怎么关闭(拼多多拼小圈怎么关闭自己的动态)

  • 青骄第二课堂登录不了是怎么回事(青骄第二课堂登录入口进入2023)

    青骄第二课堂登录不了是怎么回事(青骄第二课堂登录入口进入2023)

  • 微信标题错了怎么补救(微信标题错了怎么办)

    微信标题错了怎么补救(微信标题错了怎么办)

  • 华为mate30刷新率是多少Hz(华为mate30刷新率多少)

    华为mate30刷新率是多少Hz(华为mate30刷新率多少)

  • 苹果手机用户正忙是什么意思(苹果手机用户正忙是不是拉黑了)

    苹果手机用户正忙是什么意思(苹果手机用户正忙是不是拉黑了)

  • 华为nova7se和nova7手机壳一样吗(华为nova7se和nova7哪个好)

    华为nova7se和nova7手机壳一样吗(华为nova7se和nova7哪个好)

  • 为什么微信记录没删就没了(为什么微信记录只能保存2月)

    为什么微信记录没删就没了(为什么微信记录只能保存2月)

  • 微信删了对方对方还有聊天记录吗(微信删了对方对方还能看到聊天记录吗)

    微信删了对方对方还有聊天记录吗(微信删了对方对方还能看到聊天记录吗)

  • 闲鱼不发货多久自动取消订单(闲鱼不发货多久自动退货)

    闲鱼不发货多久自动取消订单(闲鱼不发货多久自动退货)

  • 4g高清通话功能是什么意思

    4g高清通话功能是什么意思

  • lxe文件用什么软件打开(lx文件用什么打开)

    lxe文件用什么软件打开(lx文件用什么打开)

  • 如何将歌曲下载到u盘(如何将歌曲下载到文件夹)

    如何将歌曲下载到u盘(如何将歌曲下载到文件夹)

  • 淘宝购买失败系统繁忙什么原因(淘宝购买失败稍后再试)

    淘宝购买失败系统繁忙什么原因(淘宝购买失败稍后再试)

  • 拼多多gmv什么意思(拼多多gmv计算公式)

    拼多多gmv什么意思(拼多多gmv计算公式)

  • 新买的手机就卡怎么办(新买的手机就卡了怎么回事)

    新买的手机就卡怎么办(新买的手机就卡了怎么回事)

  • qq自定义标识怎么弄(qq自定义标识怎么设置)

    qq自定义标识怎么弄(qq自定义标识怎么设置)

  • 计算机中cpu是指什么(计算机术语中cpu是指)

    计算机中cpu是指什么(计算机术语中cpu是指)

  • 固态硬盘和硬盘都要吗(固态硬盘和硬盘容量的区别)

    固态硬盘和硬盘都要吗(固态硬盘和硬盘容量的区别)

  • 苹果11为什么充电那么慢(苹果11为什么充电时摸边上麻麻的还抖)

    苹果11为什么充电那么慢(苹果11为什么充电时摸边上麻麻的还抖)

  • 手机怎么屏蔽垃圾短信(手机怎么屏蔽垃圾广告短信)

    手机怎么屏蔽垃圾短信(手机怎么屏蔽垃圾广告短信)

  • 通知栏hd收费吗(通知栏里hd)

    通知栏hd收费吗(通知栏里hd)

  • 魅族16s怎么安装SIM卡(魅族16s怎么安装两个微信)

    魅族16s怎么安装SIM卡(魅族16s怎么安装两个微信)

  • 改群名片是什么意思(群名片在哪儿改)

    改群名片是什么意思(群名片在哪儿改)

  • 拼多多双收藏怎么截图(拼多多双收藏点哪里)

    拼多多双收藏怎么截图(拼多多双收藏点哪里)

  • 探探怎么解除手机号码(探探怎么样才能解除)

    探探怎么解除手机号码(探探怎么样才能解除)

  • 苹果手机怎么设置信任开发者(苹果手机怎么设置应用锁)

    苹果手机怎么设置信任开发者(苹果手机怎么设置应用锁)

  • Linux系统中用户的登入登出命令详解(linux系统中用户账户有哪些分类)

    Linux系统中用户的登入登出命令详解(linux系统中用户账户有哪些分类)

  • vue项目遇见事件冒泡如何处理(vue事件bus)

    vue项目遇见事件冒泡如何处理(vue事件bus)

  • 企业产生的所得税计入
  • 非居民纳税机构都包含哪些?
  • 环境保护税的应税污染物有哪些
  • 个体定期定额怎么征税2023
  • 企业所得税汇算清缴时间
  • 小微企业一般要交什么费用2019
  • 原材料入库汇总单
  • 开发票系统税号0和o怎么区别
  • 工伤事故赔偿项目表
  • 企业所得税应税所得率
  • 出售无形资产计入资产处置损益还是营业外收入
  • 公众号注册验证方式
  • 辅助生产交互分配后的实际费用应在进行分配
  • 个人去税务局开劳务费税率
  • 出租车发票日期可以改吗
  • 当年缴纳的税金怎么入账
  • 退货没有红字发票怎么办
  • 劳务派遣专用发票超过9万怎么办理
  • 建安发票税率是多少2011年
  • 附加税里包括地税吗
  • 酒店怎么缴纳增值税费用
  • 建筑劳务公司人员结构
  • 销售库存商品是什么凭证
  • 折价购买债券是什么意思
  • 有未分配利润就有盈余返还吗?
  • 住房补贴计入个人所得税吗
  • 物流公司怎么进去工作的
  • 购进商品没收到货怎么办
  • 会计帐务处理程序
  • 开具红字发票后如何在申报表中填写?
  • 财政拨款收入的预算会计科目
  • 存货丢失取得赔偿
  • 银行初级证书全称
  • 收到预付卡发票分录
  • 增值税的计税依据包括消费税吗
  • laravel获取请求参数
  • php中strcmp函数
  • VUE -- defineExpose
  • phpseessid
  • 人工智能示例
  • Vue3 + Pinia 持久化存储
  • iis部署javaweb
  • 企业固定资产可以按照其价值和使用情况,确定采用某一
  • 二手车价格网站
  • 核定征收的纳税人能否享受六税两费减免
  • 为SQLite3提供一个ANSI到UTF8的互转函数
  • 土石方费用入什么科目
  • 企业所得税季报资产总额季初季末
  • 车辆购置税相关法律规定
  • 购买性支出和转移性支出的区别
  • 结算备付金账户是什么帐户
  • 个税年终奖计算方法2022税率表
  • 公司员工报销没有发票挂内账有风险吗
  • 制造费用和直接人工的关系
  • 红冲去年费用会计分录
  • 合同取得成本和销售费用
  • 企业会计制度对固定资产无入账价值怎么入账
  • 支付借款利息需要交税吗
  • 金税盘费用抵扣账务处理
  • 如何优化sql语句执行效率
  • sql server触发器主要针对下列语句创建
  • 苹果电脑win10系统打不开
  • ubuntu虚拟机怎么改用户名
  • linux模块的概念
  • centos named
  • 在Linux命令行中快速删除光标前的快捷键是什么?
  • cocos2d动画
  • 置顶如何设置固定顺序
  • opengl出错
  • nodejs怎么启动服务
  • unity获取物体的位置
  • androidstudio手机编程软件
  • 国家税务湖北税务局
  • 车价36万保险一般多少钱
  • 小规模纳税人购买车辆可以抵扣税吗
  • 12123怎么上传交强险
  • 济南高新区国家税务局
  • 房契税发票丢了能补办吗
  • 非关税壁垒英语翻译
  • 湖北退役士兵退伍费
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设