你还弄不清xxxForCausalLM和xxxForConditionalGeneration吗

科技资讯 投稿 7000 0 评论

你还弄不清xxxForCausalLM和xxxForConditionalGeneration吗

Part1基本介绍

大语言模型目前一发不可收拾,在使用的时候经常会看到transformers库的踪影,其中xxxCausalLM和xxxForConditionalGeneration会经常出现在我们的视野中,接下来我们就来聊聊transformers库中的一些基本任务。

这里以三类模型为例:bert(自编码)、gpt(自回归)、bart(编码-解码)

首先我们整体看下每个模型有什么任务:

from ..bart.modeling_bart import (
    BartForCausalLM,
    BartForConditionalGeneration,
    BartForQuestionAnswering,
    BartForSequenceClassification,
    BartModel,

from ..bert.modeling_bert import (
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertLMHeadModel,
    BertModel,

from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model

1Bert

  • BertModel(BertPreTrainedModel:最原始的bert,可获得句向量表示或者每个token的向量表示。

  • BertForPreTraining(BertPreTrainedModel:在BertModel的基础上加了一个预训练头:

self.bert = BertModel(config
self.cls = BertPreTrainingHeads(config
class BertPreTrainingHeads(nn.Module:
    def __init__(self, config:
        super(.__init__(
        self.predictions = BertLMPredictionHead(config
        self.seq_relationship = nn.Linear(config.hidden_size, 2
    def forward(self, sequence_output, pooled_output:
        prediction_scores = self.predictions(sequence_output
        seq_relationship_score = self.seq_relationship(pooled_output
        return prediction_scores, seq_relationship_score
class BertLMPredictionHead(nn.Module:
    def __init__(self, config:
        super(.__init__(
        self.transform = BertPredictionHeadTransform(config
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False
        self.bias = nn.Parameter(torch.zeros(config.vocab_size
        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias
    def forward(self, hidden_states:
        hidden_states = self.transform(hidden_states
        hidden_states = self.decoder(hidden_states
        return hidden_states
对应bert的两个训练任务:掩码语言模型(MLM)和下一个句子预测(NSP)。
  • BertLMHeadModel(BertPreTrainedModel:MLM任务
self.bert = BertModel(config, add_pooling_layer=False
self.cls = BertOnlyMLMHead(config
class BertOnlyMLMHead(nn.Module:
    def __init__(self, config:
        super(.__init__(
        self.predictions = BertLMPredictionHead(config
    def forward(self, sequence_output: torch.Tensor -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output
        return prediction_scores
class BertLMPredictionHead(nn.Module:
    def __init__(self, config:
        super(.__init__(
        self.transform = BertPredictionHeadTransform(config
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False
        self.bias = nn.Parameter(torch.zeros(config.vocab_size
        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias
    def forward(self, hidden_states:
        hidden_states = self.transform(hidden_states
        hidden_states = self.decoder(hidden_states
        return hidden_states
  • BertForNextSentencePrediction(BertPreTrainedModel:NSP任务
self.bert = BertModel(config
self.cls = BertOnlyNSPHead(config
class BertOnlyNSPHead(nn.Module:
    def __init__(self, config:
        super(.__init__(
        self.seq_relationship = nn.Linear(config.hidden_size, 2
    def forward(self, pooled_output:
        seq_relationship_score = self.seq_relationship(pooled_output
        return seq_relationship_score
  • BertForSequenceClassification(BertPreTrainedModel:对句子进行分类
self.bert = BertModel(config
classifier_dropout = (
    config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob

self.dropout = nn.Dropout(classifier_dropout
self.classifier = nn.Linear(config.hidden_size, config.num_labels
  • BertForMultipleChoice(BertPreTrainedModel::多项选择
self.bert = BertModel(config
classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    
self.dropout = nn.Dropout(classifier_dropout
self.classifier = nn.Linear(config.hidden_size, 1
  • BertForTokenClassification(BertPreTrainedModel:对token进行分类,一般为命名实体识别任务
self.bert = BertModel(config, add_pooling_layer=False
classifier_dropout = (
    config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob

self.dropout = nn.Dropout(classifier_dropout
self.classifier = nn.Linear(config.hidden_size, config.num_labels
  • BertForQuestionAnswering(BertPreTrainedModel:QA任务,很多任务都可以转换为这种形式。即识别答案的开始位置和结束位置。
self.bert = BertModel(config, add_pooling_layer=False
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels

2GPT2

  • GPT2Model(GPT2PreTrainedModel:原始的GPT2模型,返回每个token的向量。

  • GPT2LMHeadModel(GPT2PreTrainedModel:进行语言模型任务。判断每一个token的下一个token是什么、

self.transformer = GPT2Model(config
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False
  • GPT2DoubleHeadsModel(GPT2PreTrainedModel:除了语言模型任务外,额外定义了一个任务:多项选择任务。比如一个问题有两个回答,一个正确回答,一个错误回答,进行二分类任务判断哪一个是正确回答。当然可以扩展到多个选项。
 config.num_labels = 1
 self.transformer = GPT2Model(config
 self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False
 self.multiple_choice_head = SequenceSummary(config

这个要看个例子:

import torch
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2'
model = GPT2DoubleHeadsModel.from_pretrained('gpt2'
choices = [ "Bob likes candy ; what does Bob like ?  Bag <|endoftext|>",
                   "Bob likes candy ; what does Bob like ?  Burger <|endoftext|>",
                   "Bob likes candy ; what does Bob like ?  Candy <|endoftext|>",
                  "Bob likes candy ; what does Bob like ?  Apple <|endoftext|>"]
encoded_choices = [tokenizer.encode(s for s in choices]
eos_token_location = [tokens.index(tokenizer.eos_token_id for tokens in encoded_choices]
input_ids = torch.tensor(encoded_choices.unsqueeze(0 
mc_token_ids = torch.tensor([eos_token_location] 
print(input_ids.shape
print(mc_token_ids.shape
outputs = model(input_ids, mc_token_ids=mc_token_ids
lm_prediction_scores, mc_prediction_scores = outputs[:2]
print(lm_prediction_scores.shape
print(mc_prediction_scores
"""
torch.Size([1, 4, 13]
torch.Size([1, 4]
torch.Size([1, 4, 13, 50257]
tensor([[-6.0075, -6.0649, -6.0657, -6.0585]], grad_fn=<SqueezeBackward1>
"""

Confused by GPT2DoubleHeadsModel example · Issue #1794 · huggingface/transformers (github.com

How to use GPT2DoubleHeadsModel? · Issue #3680 · huggingface/transformers (github.com

  • GPT2ForSequenceClassification(GPT2PreTrainedModel:显然,针对于文本分类任务
self.transformer = GPT2Model(config
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False
  • GPT2ForTokenClassification(GPT2PreTrainedModel:针对于token分类(命名实体识别任务)
self.transformer = GPT2Model(config
if hasattr(config, "classifier_dropout" and config.classifier_dropout is not None:
    classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout" and config.hidden_dropout is not None:
    classifier_dropout = config.hidden_dropout
else:
    classifier_dropout = 0.1
    self.dropout = nn.Dropout(classifier_dropout
    self.classifier = nn.Linear(config.hidden_size, config.num_labels

举个例子:

import torch
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2'
model = GPT2Model.from_pretrained('gpt2'
text = [
  "Bob likes candy ; what does Bob like ?  Bag <|endoftext|>", 
  "Bob likes candy ; what does Bob like ?  Bag <|endoftext|>"
]
inputs = tokenizer(text, return_tensors="pt"
print(inputs
print(tokenizer.decode(inputs["input_ids"][0]
output = model(**inputs
print(output[0].shape
"""
{'input_ids': tensor([[18861,  7832, 18550,  2162,   644,   857,  5811,   588,  5633,   220,
         20127,   220, 50256],
        [18861,  7832, 18550,  2162,   644,   857,  5811,   588,  5633,   220,
         20127,   220, 50256]], 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
Bob likes candy ; what does Bob like?  Bag <|endoftext|>
torch.Size([2, 13, 768]
"""

3BART

  • BartModel(BartPretrainedModel:bart的原始模型,返回解码器每个token的向量。当然还有其它可选的。

  • BartForConditionalGeneration(BartPretrainedModel:顾名思义,条件文本生成。

self.model = BartModel(config
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False

输入一般我们需要定义:input_ids(编码器的输入)、attention_mask (编码器注意力)、decoder_input_ids(解码器的输入),target_attention_mask(解码器注意力)输出一般我们使用的有两个 loss=masked_lm_loss和 logits=lm_logits。

  • BartForSequenceClassification(BartPretrainedModel:文本分类
self.model = BartModel(config
self.classification_head = BartClassificationHead(
    config.d_model,
    config.d_model,
    config.num_labels,
    config.classifier_dropout,

class BartClassificationHead(nn.Module:
    """Head for sentence-level classification tasks."""
    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    :
        super(.__init__(
        self.dense = nn.Linear(input_dim, inner_dim
        self.dropout = nn.Dropout(p=pooler_dropout
        self.out_proj = nn.Linear(inner_dim, num_classes
    def forward(self, hidden_states: torch.Tensor -> torch.Tensor:
        hidden_states = self.dropout(hidden_states
        hidden_states = self.dense(hidden_states
        hidden_states = torch.tanh(hidden_states
        hidden_states = self.dropout(hidden_states
        hidden_states = self.out_proj(hidden_states
        return hidden_states

具体的获取logits是这么操作的:

hidden_states = outputs[0]  # last hidden state
# 找到eos_mask的位置
eos_mask = input_ids.eq(self.config.eos_token_id.to(hidden_states.device
if len(torch.unique_consecutive(eos_mask.sum(1 > 1:
    raise ValueError("All examples must have the same number of <eos> tokens."
    sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0, -1, hidden_states.size(-1[
        :, -1, :
    ]
logits = self.classification_head(sentence_representation

损失计算:

loss = None
if labels is not None:
    labels = labels.to(logits.device
    if self.config.problem_type is None:
        if self.config.num_labels == 1:
            self.config.problem_type = "regression"
        elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int:
            self.config.problem_type = "single_label_classification"
        else:
            self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss(
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(, labels.squeeze(
                else:
                    loss = loss_fct(logits, labels
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss(
                    loss = loss_fct(logits.view(-1, self.config.num_labels, labels.view(-1
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss(
  • BartForQuestionAnswering(BartPretrainedModel:问答和之前GPT基本一致,只不过这里的输入到计算logits前的向量是解码器的隐含层向量。
config.num_labels = 2
self.num_labels = config.num_labels
self.model = BartModel(config
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output
start_logits, end_logits = logits.split(1, dim=-1
start_logits = start_logits.squeeze(-1.contiguous(
end_logits = end_logits.squeeze(-1.contiguous(
  • BartForCausalLM(BartPretrainedModel:语言模型任务,只使用bart的解码器。
config = copy.deepcopy(config
config.is_decoder = True
config.is_encoder_decoder = False
super(.__init__(config
self.model = BartDecoderWrapper(config
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False
outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,

logits = self.lm_head(outputs[0]
>>> from transformers import AutoTokenizer, BartForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base"
>>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt"
>>> outputs = model(**inputs
>>> logits = outputs.logits
>>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
>>> list(logits.shape == expected_shape
True

Part2实操

接下来针对xxxCausalLM和xxxForConditionalGeneration,我们实际操作来更加深入的了解它们。首先需要安装一些依赖:

pip install transformers==4.28.1
pip install evaluate
pip install datasets

4使用GPT2进行观点评论生成

数据从这里下载:https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv

直接上代码:

import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import (
    default_data_collator,
    get_linear_schedule_with_warmup,

from torch.utils.data import DataLoader
data_file = "./ChnSentiCorp_htl_all.csv"  # 数据文件路径,数据需要提前下载
max_length = 86
train_batch_size = 64
eval_batch_size = 64
num_epochs = 10
lr = 3e-4
# 加载数据集
dataset = load_dataset("csv", data_files=data_file
dataset = dataset.filter(lambda x: x["review"] is not None
dataset = dataset["train"].train_test_split(0.2, seed=123
model_name_or_path = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path
model = AutoModelForCausalLM.from_pretrained(model_name_or_path
# example = {'label': 1, 'review': '早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。'}
def process(example:
    text = example["review"]
    # text = ["399真的很值得之前也住过别的差不多价位的酒店式公寓没有这间好厨房很像厨房很大整个格局也都很舒服早上的早餐我订的8点半的已经冷了。位置啊什么还是很好的下次还会去服务也很周到"]
    batch_size = len(text
    inputs = tokenizer(text, add_special_tokens=False, truncation=True, max_length=max_length
    inputs["labels"] = []
    for i in range(batch_size:
        input_ids = inputs["input_ids"][i]
        if len(input_ids + 1 <= max_length:
            inputs["input_ids"][i] = input_ids + [tokenizer.pad_token_id] + [0] * (max_length - len(input_ids - 1
            inputs["labels"].append(input_ids + [tokenizer.pad_token_id] + [-100] * (max_length - len(input_ids - 1
            inputs["attention_mask"][i] = [1] * len(input_ids + [0] + [0] * (max_length - len(input_ids - 1
        else:
            inputs["input_ids"][i] = input_ids[:max_length - 1] + [tokenizer.pad_token_id]
            inputs["labels"].append(inputs["input_ids"][i]
            inputs["attention_mask"][i] = [1] * max_length
        inputs["token_type_ids"][i] = [0] * max_length
        # for k, v in inputs.items(:
        #   print(k, len(v[0]
        # assert len(inputs["labels"][i] == len(inputs["input_ids"][i] == len(inputs["token_type_ids"][i] == len(inputs["attention_mask"][i] == 86
    return inputs
# process(None
train_dataset = dataset["train"].map(process, batched=True, num_proc=1, remove_columns=dataset["train"].column_names
test_dataset = dataset["test"].map(process, batched=True, num_proc=1, remove_columns=dataset["test"].column_names
train_dataloader = DataLoader(
    train_dataset, collate_fn=default_data_collator, shuffle=True, batch_size=train_batch_size, pin_memory=True

test_dataloader = DataLoader(
    test_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size, pin_memory=True

# optimizer
optimizer = torch.optim.AdamW(model.parameters(, lr=lr
# lr scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader * num_epochs,

model.cuda(
from tqdm import tqdm
for epoch in range(num_epochs:
    model.train(
    total_loss = 0
    t = tqdm(train_dataloader
    for step, batch in enumerate(t:
        for k, v in batch.items(:
            batch[k] = v.cuda(
        outputs = model(
            input_ids=batch["input_ids"],
            token_type_ids=batch["token_type_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        
        loss = outputs.loss
        t.set_description("loss:{:.6f}".format(loss.item(
        total_loss += loss.detach(.float(
        loss.backward(
        optimizer.step(
        lr_scheduler.step(
        optimizer.zero_grad(
    train_epoch_loss = total_loss / len(train_dataloader
    model.save_pretrained("gpt2-chinese/"
    print(f"epoch:{epoch+1}/{num_epochs} loss:{train_epoch_loss}"

训练结果:

loss:2.416899: 100%|██████████| 98/98 [01:51<00:00,  1.14s/it]
epoch:1/10 loss:2.7781832218170166
loss:2.174688: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]
epoch:2/10 loss:2.3192219734191895
loss:2.123909: 100%|██████████| 98/98 [01:55<00:00,  1.17s/it]
epoch:3/10 loss:2.037835121154785
loss:1.785878: 100%|██████████| 98/98 [01:55<00:00,  1.18s/it]
epoch:4/10 loss:1.7687807083129883
loss:1.466153: 100%|██████████| 98/98 [01:55<00:00,  1.18s/it]
epoch:5/10 loss:1.524872064590454
loss:1.465316: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]
epoch:6/10 loss:1.3074666261672974
loss:1.150320: 100%|██████████| 98/98 [01:54<00:00,  1.16s/it]
epoch:7/10 loss:1.1217808723449707
loss:1.043044: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:8/10 loss:0.9760875105857849
loss:0.790678: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:9/10 loss:0.8597695827484131
loss:0.879025: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:10/10 loss:0.790839433670044

可以这么进行预测:

from transformers import AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline, AutoModelForCausalLM
from datasets import load_dataset
data_file = "./ChnSentiCorp_htl_all.csv" # 数据文件路径,数据需要提前下载
dataset = load_dataset("csv", data_files=data_file
dataset = dataset.filter(lambda x: x["review"] is not None
dataset = dataset["train"].train_test_split(0.2, seed=123
model_name_or_path = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path
model = AutoModelForCausalLM.from_pretrained("./gpt2-chinese/"
text_generator = TextGenerationPipeline(model, tokenizer  
import random
examples = dataset["train"]
example = random.choice(examples
text = example["review"]
print(text
print(text[:10]
text_generator(text[:10], 
        max_length=100, 
        do_sample=False, 
        top_p=0.8, 
        repetition_penalty=10.0,
        temperature=0.95,
        eos_token_id=0, 
        
"""
第一次住在这里儿,我针对大家的意见,特别关注了一下,感觉如下吧!1、标准间虽然有点旧但很干净,被子盖得很舒服,也很暖和,卫生间也蛮大的,因是在商业中心离很多还算很近。2、酒店服务还算可以,没有像这里说的那样,入住时,退房时也挺快的,总的来说我很满意。3、早餐也还可以,环境也不错,有点江南的感觉;菜品种品也不少,挺可口。4、可能是在市或者离火车站的距离很近,稍微有点“热闹”,来找我办事的人不方便停车,但还好这里有地下停车场。总体来说,我感觉很不错,值得推荐!
第一次住在这里儿,我
[{'generated_text': '第一次住在这里儿,我 感 觉 很 温 馨 。 房 间 宽 敞 、 干 净 还 有 水 果 送 ( 每 人 10 元 ) ; 饭 菜 也 不 错 ! 价 格 合 理 经 济 实 惠 .'}]
"""

我们需要注意的几点:

  • 不同模型使用的tokenizer是不一样的,需要注意它们的区别,尤其是pad_token_id和eos_token_id。eos_token_id常常用于标识生成文本的结尾。

  • 有一些中文的生成预训练模型使用的还是Bert的tokenizer,在进行token化的时候,通过指定add_special_tokens=False来避免添加[CLS]和[SEP]。

  • BertTokenizer的eos_token_id为None,这里我们用[PAD]视为生成结束的符号,其索引为0.当然,你也可以设置它为词表里面的特殊符号,比如[SEP]。

  • 对于不需要计算损失的token,我们将其标签设置为-100。

  • 我们的labels和input_ids为什么是一样的,不是说根据上一个词生成下一个词吗?这是因为模型里面帮我们处理了,见代码:

shift_logits = lm_logits[..., :-1, :].contiguous(
shift_labels = labels[..., 1:].contiguous(
  • 进行预测有三种方式,控制文本生成的多样性有很多参数可以选择,具体刚兴趣可参考最后面的链接。

5使用BART进行对联生成

数据从这里下载:https://www.modelscope.cn/datasets/minisnow/couplet_samll.git

直接看代码:

import json
import pandas as pd
import numpy as np
# import lawrouge
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline, pipeline
from datasets import load_dataset, Dataset
from transformers import default_data_collator
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    default_data_collator,
    get_linear_schedule_with_warmup,

from torch.utils.data import DataLoader
# ============================= 
# 加载数据
train_path = "couplet_samll/train.csv"
train_dataset = Dataset.from_csv(train_path
test_path = "couplet_samll/test.csv"
test_dataset = Dataset.from_csv(test_path
max_len = 24
train_batch_size = 64
eval_batch_size = 64
lr = 3e-4
num_epochs = 1
# 转换为模型需要的格式
def tokenize_dataset(tokenizer, dataset, max_len:
  def convert_to_features(batch:
    text1 = batch["text1"]
    text2 = batch["text2"]
    inputs = tokenizer.batch_encode_plus(
      text1,
      max_length=max_len,
      padding="max_length",
      truncation=True,
    
    targets = tokenizer.batch_encode_plus(
      text2,
      max_length=max_len,
      padding="max_length",
      truncation=True,
    
    outputs = {
      "input_ids": inputs["input_ids"],
      "attention_mask": inputs["attention_mask"],
      "target_ids": targets["input_ids"],
      "target_attention_mask": targets["attention_mask"]
    }
    return outputs
  
  dataset = dataset.map(convert_to_features, batched=True
  # Set the tensor type and the columns which the dataset should return
  columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask']
  dataset.with_format(type='torch', columns=columns
  dataset = dataset.rename_column('target_ids', 'labels'
  dataset = dataset.rename_column('target_attention_mask', 'decoder_attention_mask'
  dataset = dataset.remove_columns(['text1', 'text2']
  return dataset
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese"
train_data = tokenize_dataset(tokenizer, train_dataset, max_len
test_data = tokenize_dataset(tokenizer, test_dataset, max_len
train_dataset = train_data
train_dataloader = DataLoader(
    train_dataset, collate_fn=default_data_collator, shuffle=True, batch_size=train_batch_size, pin_memory=True

test_dataset = test_data
test_dataloader = DataLoader(
    test_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size, pin_memory=True

# optimizer
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese"
optimizer = torch.optim.AdamW(model.parameters(, lr=lr
# lr scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader * num_epochs,

model.cuda(
from tqdm import tqdm
for epoch in range(num_epochs:
    model.train(
    total_loss = 0
    t = tqdm(train_dataloader
    for step, batch in enumerate(t:
        for k, v in batch.items(:
            batch[k] = v.cuda(
        outputs = model(**batch
        loss = outputs.loss
        t.set_description("loss:{:.6f}".format(loss.item(
        total_loss += loss.detach(.float(
        loss.backward(
        optimizer.step(
        lr_scheduler.step(
        optimizer.zero_grad(
    train_epoch_loss = total_loss / len(train_dataloader
    model.save_pretrained("bart-couplet/"
    tokenizer.save_pretrained("bart-couplet/"
    print(f"epoch:{epoch+1}/{num_epochs} loss:{train_epoch_loss}"

结果:

loss:1.593506: 100%|██████████| 4595/4595 [33:28<00:00,  2.29it/s]
epoch:1/1 loss:1.76453697681427

我们可以这么预测:

from transformers import Text2TextGenerationPipeline
model_path = "bart-couplet"
# model_path = "fnlp/bart-base-chinese"
model = BartForConditionalGeneration.from_pretrained(model_path
tokenizer = BertTokenizer.from_pretrained(model_path
generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer
max_len = 24
test_path = "couplet_samll/test.csv"
test_data = pd.read_csv(test_path
texts = test_data["text1"].values.tolist(
labels = test_data["text2"].values.tolist(
results = generator(texts, max_length=max_len, eos_token_id=102, pad_token_id=0, do_sample=True
for text, label, res in zip(texts, labels, results:
  print("上联:", text
  print("真实下联:", label
  print("预测下联:", "".join(res["generated_text"].split(" "
  print("="*100
    
"""
上联: 几帧山水关秋路
真实下联: 无奈胭脂点绛唇
预测下联: 天高云淡月光明
====================================================================================================
上联: 许多心事懒收拾
真实下联: 大好青春莫撂荒
预测下联: 何妨明月照寒窗
====================================================================================================
上联: 谁同执手人间老
真实下联: 自愿并肩化外游
预测下联: 心中有梦月当头
====================================================================================================
上联: 画地为牢封自步
真实下联: 齐天大圣悟空行
预测下联: 不妨一世好清闲
====================================================================================================
上联: 布谷携春临五岳
真实下联: 流莺送喜到千家
预测下联: 万家灯火庆丰年
====================================================================================================
上联: 冤家宜解不宜结
真实下联: 穷寇定歼必定追
预测下联: 不因风雨误春秋
====================================================================================================
上联: 汪伦情义人间少
真实下联: 法律条文格外繁
预测下联: 一江春水向东流
====================================================================================================
上联: 泼墨吟诗,银发人生添雅兴
真实下联: 手机短信,古稀老叟逐新潮
预测下联: 春风得意,万里千帆逐浪高
====================================================================================================
上联: 刊岫展屏山,云凝罨画
真实下联: 平湖环镜槛,波漾空明
预测下联: 千年古邑,百花芳草淹春
====================================================================================================
上联: 且向人间赊一醉
真实下联: 直如岛外泛孤舟
预测下联: 春风得意乐逍遥
====================================================================================================
"""

需要注意的地方:

  • 这里我们的输入不再是单条文本,而是文本对。
  • 我们需要构造编码器Input_ids,编码器attention_mask,解码器input_ids,解码器attention_mask。
  • 这里使用了一个技巧:采样生成,设置do_sample=True。如果你尝试设置它为False,你会发现生成的结果可能不是那么好。
  • 同样的这里使用的还是Bert的tokenizer,这里进行tokenizer的时候我们保留了bert的[CLS]和[SEP]。为了更直观的理解,我们使用另一种更直接的方法来生成结果:
model = BartForConditionalGeneration.from_pretrained(model_path
model = model.to("cuda"
model.eval(
inputs = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    
input_ids = inputs.input_ids.to(model.device
attention_mask = inputs.attention_mask.to(model.device
# 生成
outputs = model.generate(input_ids, 
              attention_mask=attention_mask, 
              max_length=max_len, 
              do_sample=True, 
              pad_token_id=0,
              eos_token_id=102
# 将token转换为文字
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=False
output_str = [s.replace(" ","" for s in output_str]
for text, label, pred in zip(texts, labels, output_str:
  print("上联:", text
  print("真实下联:", label
  print("预测下联:", pred
  print("="*100

结果:

上联: 几帧山水关秋路
真实下联: 无奈胭脂点绛唇
预测下联: [SEP][CLS]春风送暖柳含烟[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 许多心事懒收拾
真实下联: 大好青春莫撂荒
预测下联: [SEP][CLS]无私奉献为人民[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 谁同执手人间老
真实下联: 自愿并肩化外游
预测下联: [SEP][CLS]清风明月是知音[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 画地为牢封自步
真实下联: 齐天大圣悟空行
预测下联: [SEP][CLS]月明何处不相逢[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 布谷携春临五岳
真实下联: 流莺送喜到千家
预测下联: [SEP][CLS]一壶老酒醉春风[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 冤家宜解不宜结
真实下联: 穷寇定歼必定追
预测下联: [SEP][CLS]风流人物不虚名[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 汪伦情义人间少
真实下联: 法律条文格外繁
预测下联: [SEP][CLS]万里江山万里春[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 泼墨吟诗,银发人生添雅兴
真实下联: 手机短信,古稀老叟逐新潮
预测下联: [SEP][CLS]和谐社会,和谐和谐幸福家[SEP]
====================================================================================================
上联: 刊岫展屏山,云凝罨画
真实下联: 平湖环镜槛,波漾空明
预测下联: [SEP][CLS]天下无双,人寿年丰[SEP][PAD][PAD][PAD]
====================================================================================================
上联: 且向人间赊一醉
真实下联: 直如岛外泛孤舟
预测下联: [SEP][CLS]不知何处有闲人[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
  • 我们设置skip_special_tokens=False,在生成时不忽略特殊token。
  • 以"无奈胭脂点绛唇"为例。输入[SEP],预测得到[CLS],输入[SEP][CLS]得到正常的文本,最后以[SEP]结尾。因为我们的encoder_input_ids和decoder_input_ids都是加了特殊符号的。当然你可以不加或者自定义使用其它的特殊符号。

到这里,你已经了解了transformers库中自带的模型及相关的一些任务了,特别是针对生成模型有了更深一层的了解,赶紧去试试吧。

最后附上相关的一些知识:

https://zhuanlan.zhihu.com/p/624845975

Part3参考

transformers.models.auto.modeling_auto — transformers 4.4.2 documentation (huggingface.co

编程笔记 » 你还弄不清xxxForCausalLM和xxxForConditionalGeneration吗

赞同 (36) or 分享 (0)
游客 发表我的评论   换个身份
取消评论

表情
(0)个小伙伴在吐槽