Python 自己训练chatGPT,实例代码如下;简单易懂的训练chatGPT,模板实例;自己训练chatGPT

news/2024/7/20 17:40:12 标签: python, chatgpt, NLP, LLMs, 大模型, 大语言模型, 人工智能

代码实例:

比较简单的示例,其它gpt架构相关知识和代码移步专栏其它文章。 

python">from torchtext.datasets import WikiText2 # 导入WikiText2
from torchtext.data.utils import get_tokenizer # 导入Tokenizer分词工具
from torchtext.vocab import build_vocab_from_iterator # 导入Vocabulary工具
tokenizer = get_tokenizer("basic_english") # 定义数据预处理所需的tokenizer
train_iter = WikiText2(split='train') # 加载WikiText2数据集的训练部分
# 定义一个生成器函数,用于将数据集中的文本转换为tokens
def yield_tokens(data_iter):
    for item in data_iter:
        yield tokenizer(item)
# 创建词汇表,包括特殊tokens:"<pad>", "<sos>", "<eos>"
vocab = build_vocab_from_iterator(yield_tokens(train_iter), 
                                  specials=["<pad>", "<sos>", "<eos>"])
vocab.set_default_index(vocab["<pad>"])

# 打印词汇表信息
print("词汇表大小:", len(vocab))
print("词汇示例(word to index):", 
      {word: vocab[word] for word in ["<pad>", "<sos>", "<eos>", "the", "apple"]})

python">import torch #导入torch
from torch.utils.data import Dataset #导入Dataset

class ChatDataset(Dataset):
    def __init__(self, file_path, tokenizer, vocab):
        self.tokenizer = tokenizer #分词器
        self.vocab = vocab #词汇表
        self.input_data, self.target_data = self.load_and_process_data(file_path)
    def load_and_process_data(self, file_path):        
        with open(file_path, "r") as f:
            lines = f.readlines() # 打开文件,读取每一行数据
        input_data, target_data = [], []
        for i, line in enumerate(lines):
            if line.startswith("User:"): # 移除 "User: " 前缀,构建输入序列
                tokens = self.tokenizer(line.strip()[6:])  
                tokens = ["<sos>"] + tokens + ["<eos>"]
                indices = [self.vocab[token] for token in tokens]
                input_data.append(torch.tensor(indices, dtype=torch.long))
            elif line.startswith("AI:"): # 移除 "AI: " 前缀,构建目标序列
                tokens = self.tokenizer(line.strip()[4:])  
                tokens = ["<sos>"] + tokens + ["<eos>"]
                indices = [self.vocab[token] for token in tokens]
                target_data.append(torch.tensor(indices, dtype=torch.long))
        return input_data, target_data
    def __len__(self): #数据集长度
        return len(self.input_data) 
    def __getitem__(self, idx): #根据索引获取数据样本
        return self.input_data[idx], self.target_data[idx] 

file_path = "chat.txt" # 加载chat.txt数据集
chat_dataset = ChatDataset(file_path, tokenizer, vocab)

for i in range(3): # 打印几个样本数据
    input_sample, target_sample = chat_dataset[i]
    print(f"Sample {i + 1}:")
    print("Input Data: ", input_sample)
    print("Target Data: ", target_sample)
    print("-" * 50)

python">from torch.utils.data import DataLoader # 导入Dataloader
# 定义pad_sequence函数,用于将一批序列补齐到相同长度
def pad_sequence(sequences, padding_value=0, length=None):
    # 计算最大序列长度,如果length参数未提供,则使用输入序列中的最大长度
    max_length = max(len(seq) for seq in sequences) if length is None else length    
    # 创建一个具有适当形状的全零张量,用于存储补齐后的序列
    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)    
    # 遍历序列,将每个序列的内容复制到结果张量中
    for i, seq in enumerate(sequences):
        end = len(seq)
        result[i, :end] = seq[:end]
    return result

# 定义collate_fn函数,用于将一个批次的数据整理成适当的形状
def collate_fn(batch):
    # 从批次中分离源序列和目标序列
    sources, targets = zip(*batch)    
    # 计算批次中的最大序列长度
    max_length = max(max(len(s) for s in sources), max(len(t) for t in targets))    
    # 使用pad_sequence函数补齐源序列和目标序列
    sources = pad_sequence(sources, padding_value=vocab["<pad>"], length=max_length)
    targets = pad_sequence(targets, padding_value=vocab["<pad>"], length=max_length)    
    # 返回补齐后的源序列和目标序列
    return sources, targets

# 创建Dataloader
batch_size = 2
chat_dataloader = DataLoader(chat_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 检查Dataloader输出
for input_batch, target_batch in chat_dataloader:
    print("Input batch tensor size:", input_batch.size())
    print("Target batch tensor size:", target_batch.size())
    break

python">from GPT_Model import GPT #导入GPT模型的类(这是我们自己制作的)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT(len(vocab), max_seq_len=256, n_layers=6).to(device) #创建模型示例
model.load_state_dict(torch.load('trained_model_2023-05-05_14-08-24.pt')) #加载模型
# model.eval()
python">import torch.nn as nn #导入nn
import torch.optim as optim #导入优化器
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"]) #损失函数
optimizer = optim.Adam(model.parameters(), lr=0.0001) # 优化器
for epoch in range(100): # 开始训练
    for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader):       
        optimizer.zero_grad()  # 梯度清零        
        input_batch, target_batch = input_batch.to(device), target_batch.to(device) #移动到设备               
        outputs = model(input_batch) # 前向传播,计算模型输出       
        loss = criterion(outputs.view(-1, len(vocab)), target_batch.view(-1)) # 计算损失           
        loss.backward() # 反向传播        
        optimizer.step() # 更新参数    
    if (epoch + 1) % 20 == 0: # 每200个epoch打印一次损失值
        print(f"Epoch: {epoch + 1:04d}, cost = {loss:.6f}")

python">def generate_text_beam_search(model, input_str, max_len=50, beam_width=5):
    model.eval()  # 将模型设置为评估(测试)模式,关闭dropout和batch normalization等训练相关的层
    # 将输入字符串中的每个token 转换为其在词汇表中的索引
    input_tokens = [vocab[token] for token in input_str]
    # 创建一个列表,用于存储候选序列
    candidates = [(input_tokens, 0.0)]
    with torch.no_grad():  # 禁用梯度计算,以节省内存并加速测试过程
        for _ in range(max_len):  # 生成最多max_len个tokens
            new_candidates = []
            for candidate, candidate_score in candidates:
                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)
                outputs = model(inputs) # 输出 logits形状为[1, len(output_tokens), vocab_size]
                logits = outputs[:, -1, :] # 只关心最后一个时间步(即最新生成的token)的logits
                # 将<pad>标记的得分设置为一个很大的负数,以避免选择它
                logits[0, vocab["<pad>"]] = -1e9 # 不是这个原因,注意不认识的词汇都变成0
                # 找到具有最高分数的前beam_width个tokens
                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)
                final_results = []
                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):
                    new_candidate = candidate + [next_token.item()]
                    new_score = candidate_score - score.item()  # 使用负数,因为我们需要降序排列
                    if next_token.item() == vocab["<eos>"]:
                        # 如果生成的token是EOS(结束符),将其添加到最终结果中
                        final_results.append((new_candidate, new_score))
                    else:
                        # 将新生成的候选序列添加到新候选列表中
                        new_candidates.append((new_candidate, new_score))
            # 从新候选列表中选择得分最高的beam_width个序列
            candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]
    # 选择得分最高的候选序列
    best_candidate, _ = sorted(candidates, key=lambda x: x[1])[0]
    # 将输出 token 转换回文本字符串
    output_str = " ".join([vocab.get_itos()[token] for token in best_candidate])
    return output_str
python">input_str = "what is the weather like today ?"
input_str = "hi , how are you ?"
# input_str = "hi , what is you name ?"

generated_text = generate_text_beam_search(model, input_str.split())
print("Generated text:", generated_text)

 

 

 

 

 


http://www.niftyadmin.cn/n/5410720.html

相关文章

24计算机考研调剂 | 山东科技大学

山东科技大学王瑞副教授团队&#xff08;先进成形技术与装备&#xff09;硕士招生 考研调剂招生信息 学校:- 专业:工学->机械工程->机械制造及其自动化 工学->材料科学与工程->材料加工工程 工学->治金工程->钢铁冶金 工学->计算机科学与技术 年级:20…

二分查找算法:高效搜索有序数据的利器

二分查找算法&#xff1a;高效搜索有序数据的利器 在计算机科学中&#xff0c;搜索是一项基本而重要的操作。对于有序数据&#xff0c;二分查找算法是一种高效的搜索方法。本文将介绍二分查找算法的原理、实现以及其在实际应用中的优势&#xff0c;帮助读者理解和应用这一常用的…

从租完ecs云服务器 使用docker建立用户 全过程

一 登录root用户 ssh root公网ip 输入密码&#xff0c;若没有密码可以前往阿里云设置服务器root密码 二 创建新用户 并赋予 新用户sudo权限 adduser $USER usermod -aG sudo $USER 三 Ubuntu安装docker sudo apt-get remove docker docker-engine docker.io containerd ru…

机器视觉相关网站推荐

机器视觉相关网站推荐 - 知乎计算机视觉&#xff08;Computer Vision&#xff09;最近这几年发展迅速&#xff0c;技术论坛和QQ群也如雨后春笋&#xff0c;特别是人脸检测/识别领域。近段时间云从君从研究院那里探得了好几个论坛&#xff0c;收藏了若干网站&#xff0c;记录下&…

C#面:Application , Cookie 和 Session 会话有什么不同

Application、Cookie 和 Session 是在Web开发中常用的三种会话管理方式 Application&#xff08;应用程序&#xff09;&#xff1a; Application 是在服务器端保存数据的一种方式&#xff0c;它可以在整个应用程序的生命周期内共享数据。Application 对象是在应用程序启动时创…

分布式事务(SeataClient)

问题场景 元数据 库存 100订单记录为空 下单操作 AutowiredRestTemplate restTemplate;/*** 下单** return*/Transactional // 开启事务 异常后触发数据库回滚操作Overridepublic Order create(Order order) {// 插入订单orderMapper.insert(order);// 扣减库存 MultiValueM…

MySQL进阶之(四)InnoDB数据存储结构之行格式

四、InnoDB数据存储结构之行格式 4.1 行格式的语法4.2 COMPACT 行格式4.2.1 记录的额外信息01、变长字段长度列表02、NULL 值列表03、记录头信息 4.2.2 记录的真实数据 4.3 Dynamic 和 Compressed 行格式4.3.1 字段的长度限制4.3.2 行溢出4.3.3 Dynamic 和 Compressed 行格式 4…

大数据开发(Java面试真题-卷一)

大数据开发&#xff08;Java面试真题&#xff09; 1、讲一下 JVM 的垃圾回收的相关概念&#xff1f;2、JVM 常见调优方法有哪些&#xff1f;3、请解释下Java中的线程池是什么&#xff0c;如何使用线程池来提高程序的性能&#xff1f;4、请解释Java中HashMap和HashTable的区别&a…