深度学习笔记(8)预训练模型

深度学习笔记(8)预训练模型

文章目录

  • 深度学习笔记(8)预训练模型
  • 一、预训练模型构建
  • 一、微调模型,训练自己的数据
    • 1.导入数据集
    • 2.数据集处理方法
    • 3.完形填空训练
  • 使用分词器将文本转换为模型的输入格式
  • 参数 return_tensors="pt" 表示返回PyTorch张量格式
  • 执行模型预测


一、预训练模型构建

加载模型和之前一样,用别人弄好的

# 导入warnings模块,用于忽略后续代码中可能出现的警告信息
import warnings
# 设置warnings模块,使其忽略所有警告
warnings.filterwarnings("ignore")

# 从transformers库中导入AutoModelForMaskedLM类,该类用于预训练的掩码语言模型
from transformers import AutoModelForMaskedLM

# 指定模型检查点,这里使用的是distilbert-base-uncased模型
model_checkpoint = "distilbert-base-uncased"
# 使用from_pretrained方法加载预训练的模型,该方法将从指定的检查点加载模型
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

咱们的任务就是去预测MASK到底是个啥

text = "This is a great "  # 定义一个文本字符串

# 从指定的模型检查点加载分词器
# model_checkpoint 是之前定义的模型检查点路径,用于加载与模型配套的分词器
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# 使用分词器将文本转换为模型的输入格式
# 参数 return_tensors="pt" 表示返回PyTorch张量格式
inputs = tokenizer(text, return_tensors="pt")

# inputs 现在是一个包含模型输入的张量或字典,可以用于模型推理
{'input_ids': tensor([[ 101, 2023, 2003, 1037, 2307,  103, 1012,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

下面的代码可以看到mask的id是103

tokenizer.mask_token_id
103

一、微调模型,训练自己的数据

1.导入数据集

from datasets import load_dataset  # 从datasets库中导入load_dataset函数

imdb_dataset = load_dataset("imdb")  # 使用load_dataset函数加载IMDB数据集

这段代码是使用 datasets 库来加载 IMDB 数据集。IMDB 数据集是一个用于情感分析的经典数据集,包含了两类电影评论:正面和负面。
本身是带标签的 正面和负面

  • 0 表示negative
  • 1表示positive

先查看下数据集的数据

 sample = imdb_dataset["train"].shuffle(seed=42).select(range(3))  # 从训练数据中随机选择3个样本
 for row in sample:
    print(f"\n'>>> Review: {row['text']}'")  # 打印样本的文本内容
    print(f"'>>> Label: {row['label']}'")   # 打印样本的标签

其中一个如下

'>>> Review: This movie is a great. The plot is very true to the book which is a classic written by Mark Twain. The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your toe on the moon" It reminds me of Sinatra's song High Hopes, it is fun and inspirational. The Music is great throughout and my favorite song is sung by the King, Hank (bing Crosby) and Sir "Saggy" Sagamore. OVerall a great family movie or even a great Date movie. This is a movie you can watch over and over again. The princess played by Rhonda Fleming is gorgeous. I love this movie!! If you liked Danny Kaye in the Court Jester then you will definitely like this movie.'
'>>> Label: 1'

但是我们要做完形填空,标签是没用的

2.数据集处理方法

这里文本长度要统一

  • 计算每一个文本的长度(word_ids)
  • 指定chunk_size,然后将所有数据按块进行拆分,比如每块128个,句子是700字节,要分成128,128,。。。。这种

先定义个函数,这样下面使用可以直接调用

def tokenize_function(examples):
    # 调用分词器(tokenizer)的函数,传入输入的文本数据集
    result = tokenizer(examples["text"])
    
    # 如果分词器支持快速模式,则生成单词索引(word_ids)
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    
    # 返回转换后的结果
    return result

进行文本处理

tokenized_datasets = imdb_dataset.map(
    tokenize_function,  # 应用tokenize_function函数到每个样本
    batched=True,       # 是否将数据分成批次处理
    remove_columns=["text", "label"]  # 要从数据集中移除的列
)

首先,imdb_dataset.map() 方法被用来应用 tokenize_function 函数到 imdb_dataset 的 train 部分。这个方法会对数据集中的每个样本应用指定的函数,并返回一个新的数据集,其中包含应用函数后的结果。

batched=True 参数告诉 map 方法将输入数据分成批次进行处理。这通常是为了提高效率,尤其是在处理大型数据集时。

remove_columns=[“text”, “label”] 参数告诉 map 方法在处理数据时移除指定的列。在这个例子中,它移除了 text 和 label 列,因为 text 列已经被处理为模型的输入,而 label 列不再需要,因为咱们是完形填空任务,不需要标签。

然后进行切分

tokenizer.model_max_length
chunk_size = 128

tokenizer.model_max_length 是一个属性,它表示模型能够接受的最大输入长度。这个属性通常用于序列标注任务,以确保输入的长度不超过模型的最大接受长度。

chunk_size 是一个参数,用于指定将输入序列分割成小块的大小。这个参数通常用于处理过长的输入序列,以便将其分割成多个小块,然后分别处理这些小块。
因为上限是512 所以你切分要是64 128 256这种

切分的时候也可以先看下文本长度

# 看看每一个都多长
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Review {idx} length: {len(sample)}'")
'>>> Review 0 length: 363'
'>>> Review 1 length: 304'
'>>> Review 2 length: 133'

这里看出128切分比较合适
先拿着三个试试

concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()#计算拼一起有多少个,
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")
'>>> Concatenated reviews length: 800'

下面进行切分

chunks = {
    # 使用字典推导式(dict comprehension)创建一个新的字典,其中键是序列的键(k),值是分割后的块
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    # 遍历concatenated_examples字典中的每个键值对
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    # 打印每个chunk的长度
    print(f"'>>> Chunk length: {len(chunk)}'")
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 32'

最后那个不够数,直接给他丢弃

def group_texts(examples):
    # 将所有的文本实例拼接到一起
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    
    # 计算拼接后的总长度
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    
    # 使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size,以确保不会出现多余的文本
    total_length = (total_length // chunk_size) * chunk_size
    
    # 根据计算出的总长度,对拼接后的文本进行切分
    result = {
        # 使用字典推导式(dict comprehension)创建一个新的字典,其中键是原始字典的键(k),值是切分后的块
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        # 遍历concatenated_examples字典中的每个键值对
        for k, t in concatenated_examples.items()
    }
    
    # 如果完型填空任务需要使用标签,则将标签复制到结果字典中
    result["labels"] = result["input_ids"].copy()
    
    # 返回分割后的结果
    return result

使用整除运算符(//)计算每个chunk的长度,然后乘以chunk_size这个就是说如果小于128 整除后就得0,就没了

  train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 61291
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 59904
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 122957
    })
})

会发现数据量大了

3.完形填空训练

from transformers import DataCollatorForLanguageModeling  # 从transformers库中导入DataCollatorForLanguageModeling类

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)  # 创建一个数据收集器

然后再举个例子看看mask长啥样

samples = [lm_datasets["train"][i] for i in range(2)]  # 从训练数据中选择2个样本
for sample in samples:
    _ = sample.pop("word_ids")  # 移除样本中的"word_ids"键,因为发现没啥用
    print(sample)
 for chunk in data_collator(samples)["input_ids"]:
 print(f"\n'>>> {tokenizer.decode(chunk)}'")#tokenizer.decode(chunk)是为了让被隐藏的更明显点,所以直接decode出来
 print(len(chunk))

训练过程

from transformers import TrainingArguments  # 从transformers库中导入TrainingArguments类

batch_size = 64  # 定义训练和评估时的批次大小
# 计算每个epoch打印结果的步数
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]  # 从模型检查点路径中提取模型名称

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-imdb",  # 指定输出目录,其中包含微调后的模型和日志文件
    overwrite_output_dir=True,  # 是否覆盖现有的输出目录
    evaluation_strategy="epoch",  # 指定评估策略,这里为每个epoch评估一次
    learning_rate=2e-5,  # 学习率
    weight_decay=0.01,  # 权重衰减系数
    per_device_train_batch_size=batch_size,  # 每个GPU的训练批次大小
    per_device_eval_batch_size=batch_size,  # 每个GPU的评估批次大小
    logging_steps=logging_steps,  # 指定每个epoch打印结果的步数
    num_train_epochs=1,  # 指定训练的epoch数量
    save_strategy='epoch',  # 指定保存策略,这里为每个epoch保存一次模型
)

生成的模型在这里插入图片描述

from transformers import Trainer  # 从transformers库中导入Trainer类

trainer = Trainer(
    model=model,  # 指定要训练的模型实例
    args=training_args,  # 指定训练参数对象
    train_dataset=downsampled_dataset["train"],  # 指定训练数据集
    eval_dataset=downsampled_dataset["test"],  # 指定评估数据集
    data_collator=data_collator,  # 指定数据收集器
)

评估标准使用困惑度

import math  # 导入math模块,用于计算对数和指数

eval_results = trainer.evaluate()  # 使用Trainer的evaluate方法评估模型

print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")  # 打印评估结果中的对数

人话就是你不得在mask那挑啥词合适吗,平均挑了多少个才能答对

训练模型

trainer.train() 
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

发现困惑度降低了

from transformers import AutoModelForMaskedLM

model_checkpoint  = "./distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained("./distilbert-base-uncased-finetuned-imdb/checkpoint-157")

加载自己的模型
import torch # 导入torch库,用于处理张量

使用分词器将文本转换为模型的输入格式

参数 return_tensors=“pt” 表示返回PyTorch张量格式

inputs = tokenizer(text, return_tensors=“pt”)

执行模型预测

# model 是一个预训练的BERT模型实例
# **inputs 表示将inputs字典中的所有键值对作为关键字参数传递给model
token_logits = model(**inputs).logits

# 找到遮蔽词在输入中的索引
# inputs["input_ids"] 是模型输入的词汇索引张量
# tokenizer.mask_token_id 是遮蔽词的词汇索引
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

# 获取遮蔽词的预测logits
# mask_token_index 是遮蔽词在输入中的索引张量
# token_logits 是模型输出的预测logits张量
mask_token_logits = token_logits[0, mask_token_index, :]

# 找到前5个最可能的替换词
# torch.topk 函数用于找到最大k个值及其索引
# dim=1 表示在第二个维度(即词汇维度)上进行排序
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

# 打印每个最可能的替换词及其替换后的文本
for token in top_5_tokens:
    # 使用 tokenizer.decode 方法将索引转换为文本
    # text 是原始文本
    # tokenizer.mask_token 是遮蔽词
    # [token] 是替换词的索引列表
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")
'>>> This is a great deal.'
'>>> This is a great idea.'
'>>> This is a great adventure.'
'>>> This is a great film.'
'>>> This is a great movie.'

可以看到结果和上面的通用的不一样了,这里是film movie这些针对你的训练数据的了--------

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/879759.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第417题太平洋大西洋水流问题

题目&#xff1a; 题解&#xff1a; class Solution {static int[][] dirs {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};int[][] heights;int m, n;public List<List<Integer>> pacificAtlantic(int[][] heights) {this.heights heights;this.m heights.length;this.n…

【JSrpc破解前端加密问题】

目录 一、背景 二、项目介绍 三、JSrpc 处理前端加密步骤 一、背景 解决日常渗透测试、红蓝对抗中的前端密码加密问题&#xff0c;让你的爆破更加丝滑&#xff1b;降低js逆向加密的难度&#xff0c;降低前端加密逻辑分析工作量和难度。 二、项目介绍 运行服务器程序和js脚本…

springCloud(一)注册中心

1.Eureka 要是user-service服务有多个&#xff0c;order-service该怎么调用&#xff1f; 这就需要用到 注册中心 了 。 1.1 搭建Eureka服务 1. pom引入依赖 <dependencies><!--eureka服务端--><dependency><groupId>org.springframework.cloud</gr…

线程局部变量

开发线程的步骤 为什么要学 ThreadLocal 就是为了防止开发随意选库&#xff0c;设置线程局部变量 因为初始化随着项目启动-创建了连接池&#xff0c;但目前getinfo和login都是走从库&#xff0c;没有分开 所以在service层方法运行时&#xff0c;用ReqAop代码提前算出此方法走…

将有序数组——>二叉搜索树

给你一个整数数组 nums &#xff0c;其中元素已经按 升序 排列&#xff0c;请你将其转换为一棵平衡二叉搜索树。 示例 1&#xff1a; 输入&#xff1a;nums [-10,-3,0,5,9] 输出&#xff1a;[0,-3,9,-10,null,5] 解释&#xff1a;[0,-10,5,null,-3,null,9] 也将被视为正确答案…

Modbus_RTU和Modbus库

目录 一.Modbus_RTU 1. 与Modbus TCP的区别 2. Modbus RTU特点 3. Modbus RTU协议格式 4. 报文详解 5. 代码实现RTU通信 1. 打开模拟的RTU从机 2. linux端使用代码实现和串口连接 2.1. 框架搭建 2.2 代码 二.Modbus库 1.库函数 一.Modbus_RTU 1. 与Modbus T…

C++初阶学习第六弹------标准库中的string类

目录 一.标准库中的string类 二.string的常用接口函数 2.1string类对象的构造 2.2 string的容量操作 2.3 string类的访问与遍历 2.4 string类对象的修改 2.5 string类常用的非成员函数 三、总结 一.标准库中的string类 可以简单理解成把string类理解为变长的字符数组&#x…

c++234继承

#include<iostream> using namespace std;//public 修饰的成员便俩个和方法都能使用 //protected&#xff1a;类的内部 在继承的子类中可使用 class Parents { public:int a;//名字 protected:int b;//密码 private:int c;//情人public:void printT(){cout << &quo…

C:字符串函数(完)-学习笔记

目录 前言&#xff1a; 1、strstr 1.1 strstr的使用 4.2 strstr的模拟实现 5、strtok 5.1 strtok函数的介绍 5.2 strtok函数的使用 6、strerror 前言&#xff1a; 这篇文章将介绍strstr函数&#xff0c;strtok函数&#xff0c;strerror函数 1、strstr 1.1 strstr的使用…

RabbitMQ 高级特性——持久化

文章目录 前言持久化交换机持久化队列持久化消息持久化 前言 前面我们学习了 RabbitMQ 的高级特性——消息确认&#xff0c;消息确认可以保证消息传输过程的稳定性&#xff0c;但是在保证了消息传输过程的稳定性之后&#xff0c;还存在着其他的问题&#xff0c;我们都知道消息…

Linux内核结构

Linux内核结构 文章目录 Linux内核结构一、Linux内核结构介绍1.1 总体结构&#xff1a;1.2 Linux内核结构框图&#xff1a; 二、图解Linux系统架构三、shell3.1 shell的含义&#xff1a;3.2 shell的作用&#xff1a;3.3 shell的类型&#xff1a;3.4 shell的使用&#xff1a;3.5…

安泰电压放大器设计方法是什么样的

电压放大器是电子领域中常用的设备&#xff0c;用于将低电压信号放大成高电压信号。电压放大器在信号处理、通信系统、仪器测量、控制系统、医疗设备和研究和实验室等领域都有着广泛的应用。 电压放大器的设计方法主要包括选择合适的放大器拓扑结构、选择适当的放大器参数以及进…

72v-80V降5V1.5A恒压降压WT6035

72v-80V降5V1.5A恒压降压WT6035 WT6035 是一款高压降压开关稳压器&#xff0c;可用于将 72V - 80V 的电压降为 5V、1.5A 的恒压输出&#xff0c;以下是一些关于它的特点及应用注意事项&#xff1a; 芯片特点&#xff1a; 宽电压输入范围&#xff1a;输入电压范围为 5V 至 100V…

设计模式之命令模式:从原理到实战,深入解析及源码应用

&#x1f3af; 设计模式专栏&#xff0c;持续更新中 欢迎订阅&#xff1a;JAVA实现设计模式 &#x1f6e0;️ 希望小伙伴们一键三连&#xff0c;有问题私信都会回复&#xff0c;或者在评论区直接发言 命令模式 什么是命令模式&#xff1f; 命令模式&#xff08;Command Pattern…

sensitive-word 敏感词 v0.20.0 数字全部匹配,而不是部分匹配

敏感词系列 sensitive-word-admin 敏感词控台 v1.2.0 版本开源 sensitive-word-admin v1.3.0 发布 如何支持分布式部署&#xff1f; 01-开源敏感词工具入门使用 02-如何实现一个敏感词工具&#xff1f;违禁词实现思路梳理 03-敏感词之 StopWord 停止词优化与特殊符号 04-…

《微信小程序实战(3) · 推广海报制作》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

VISIA 皮肤检测

费用:自费158元 不能医保报销 先清洁肌肤,然后做一个皮肤检测. 1200万像素高清摄像头,一个白光,一个偏正光,还有一个紫外光,三种模式,分析面部情况. 8张图 反应皮肤情况应用: 在医美前和医美一次修复完成后,皮肤情况对比. 数值越高 越好 斑点图: 皱纹图: 分数比较低的话,皮肤…

SpringBoot教程(三十) | SpringBoot集成Shiro(权限框架)

SpringBoot教程&#xff08;三十&#xff09; | SpringBoot集成Shiro&#xff08;权限框架&#xff09; 一、 什么是Shiro二、Shiro 组件核心组件其他组件 三、流程说明shiro的运行流程 四、SpringBoot 集成 Shiro1. 添加 Shiro 相关 maven2. 添加 其他 maven3. 设计数据库表4.…

268页PPT大型集团智慧工厂信息化顶层架构设计(2024版)

智能制造装备是高端制造业的关键&#xff0c;通过整合智能传感、控制、AI等技术&#xff0c;具备了信息感知、分析规划等智能化功能&#xff0c;能显著提升加工质量、效率和降低成本。该装备是先进制造、信息、智能技术的深度融合。其原理主要包括物联网集成、大数据分析与人工…

【数据结构与算法 | 灵神题单 | 合并链表篇】力扣2, 21, 445, 2816

1. 力扣2&#xff1a;两数相加 1.1 题目&#xff1a; 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可…