首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >聊聊 从源码来看ChatGLM-6B的模型结构

聊聊 从源码来看ChatGLM-6B的模型结构

作者头像
Ryan_OVO
发布于 2024-01-07 02:49:56
发布于 2024-01-07 02:49:56
2.4K00
代码可运行
举报
文章被收录于专栏:程序随笔程序随笔
运行总次数:0
代码可运行

基于ChatGLM-6B第一版,要注意还有ChatGLM2-6B以及ChatGLM3-6B

概述

ChatGLM是transformer架构的神经网络模型,因此从transformer结构入手,分析其源码结构。 transformer结构:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

位置编码

ChatGLM-6B的位置编码采用的旋转位置编码(RoPB)实现。其源码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

激活函数

ChatGLM-6B采用的激活函数是GeLU(高斯误差线性单元),其源码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)

编码器-解码器(encoder-decoder)

接下来就是编码器解码器结构,如何抓住模型源头来分析?可以从transformers的API入手:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()

print(mode)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

输出:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)

从脑图的角度来梳理下其结构

其结构图表示如下:

将结构图与最开始的transformer结构图对比来看,两者还是比较符合的。 官方源码中标注了编码器与解码器是一体的,只需要配置参数即可切换为解码器。如下:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-01-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
十分钟读懂旋转编码(RoPE)
旋转位置编码(Rotary Position Embedding,RoPE)是论文 Roformer: Enhanced Transformer With Rotray Position Embedding 提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA、GLM 模型也是采用该位置编码方式。
zenRRan
2023/09/11
6.6K0
十分钟读懂旋转编码(RoPE)
ChatGLM3 源码解析(三)
ApacheCN_飞龙
2024/03/08
4470
最强英文开源模型LLaMA架构探秘,从原理到源码
读完本文,你可能觉得LLaMA会开源并不令人惊讶,因为它的架构可以说是站在巨人肩膀上摘苹果——基本上可以说使用其他模型的组件作为“积木”搭了一个新模型出来,并没有太多实质意义上的创新,但这种敢于开源的勇气和做法使得LLaMA足以在大语言模型上的开源发展历程上成为一个标志性的里程碑。
Steve Wang
2023/10/23
4.5K0
最强英文开源模型LLaMA架构探秘,从原理到源码
聊聊ChatGLM-6B医疗数据微调
参考了多个医疗大模型,如扁鹊、灵心等,重新思考了下微调的方案以及数据集的格式;基于ChatGLM/其它LLM整合多种微调方法的非官方实现的框架,审视其数据集格式,以及调试效果,进行微调。 最终基于liucongg/ChatGLM-Finetuning开源框架成功的微调出来我想要的结果。
Ryan_OVO
2024/03/17
5370
聊聊ChatGLM-6B医疗数据微调
机器学习|从0开发大模型之模型预训练
继续写《从0开发大模型》系列文章,本文主要介绍预训练过程。 预训练是目的是让模型学习知识,需要将预处理的数据(《机器学习|从0开发大模型之数据预处理》)中生成的 pretrain_data.bin 文件的上下文全部学习到,那预训练怎么做呢?
用户1904552
2025/02/27
2520
机器学习|从0开发大模型之模型预训练
聊聊ChatGLM-6B的源码分析
作用:在微调时(以P-Tuning V2为例),方法训练时冻结模型的全部参数,只激活PrefixEncoder的参数。 其源码如下,整体来看是比较简单的。
Ryan_OVO
2024/01/09
7050
聊聊ChatGLM-6B的源码分析
Deepseek-V2技术报告解读!全网最细!
深度求索Deepseek近日发布了v2版本的模型,沿袭了1月发布的 Deepseek-MoE(混合专家模型)的技术路线,采用大量的小参数专家进行建模,同时在训练和推理上加入了更多的优化。沿袭了一贯的作风,Deepseek对模型(基座和对话对齐版本)进行了完全的mit协议开源,可以商用。对于算力不是那么充足的开发者,官方提供了API调用的方案,费用更是达到了全场最低!
zenRRan
2025/02/03
1.4K0
Deepseek-V2技术报告解读!全网最细!
聊聊ChatGLM-6B源码分析(二)
GLM模型中位置编码是2D的,有两层的位置表示,分别是序列的位置表示和mask block的位置表示。由get_position_ids函数处理。position_ids对应GLM论文中的postion 1,block_position_ids对应GLM论文中的position 2。
Ryan_OVO
2024/01/13
4730
聊聊ChatGLM-6B源码分析(二)
Llama也中招,混合精度下位置编码竟有大坑,百川智能给出修复方案
位置编码技术是一种能够让神经网络建模句子中 Token 位置信息的技术。在 Transformer 大行其道的时代,由于 Attention 结构无法建模每个 token 的位置信息,位置编码(Position embedding) 成为 Transformer 非常重要的一个组件。研究人员也提出了各种各样的位置编码方案来让网络建模位置信息,Rope 和 Alibi 是目前最被广泛采纳的两种位置编码方案。
机器之心
2023/09/08
7380
Llama也中招,混合精度下位置编码竟有大坑,百川智能给出修复方案
【论文复现】时序预测:多头注意力+宽度学习
Liyun Su, Lang Xiong和Jialing Yang在2024年发表了题为“Multi-Attn BLS: Multi-head attention mechanism with broad learning system for chaotic time series prediction”的论文,发表在《Applied Soft Computing》杂志上(CiteScore14.3,影响因子8.7)。这篇论文针对混沌时间序列数据的高复杂性和非线性提出了一种新的范式,即将宽度学习模型与多头自注意力机制相结合。在此之前,将这两种高度非线性映射算法融合的主要方法是使用堆叠的多头自注意力来提取特征,然后使用宽度学习模型进行分类预测。这篇论文提出了一种直接将多头注意力模块集成到宽度学习中的方法,从而实现了端到端的预测模型。
Eternity._
2024/11/28
3440
【论文复现】时序预测:多头注意力+宽度学习
Llama深入浅出
前方干货预警:这可能是你能够找到的最容易懂的最具实操性的学习开源LLM模型源码的教程。
lyhue1991
2023/09/05
2.5K1
Llama深入浅出
大模型部署框架 FastLLM 实现细节解析
以chatglm-6b的支持为例,函数入口在 https://github.com/ztxz16/fastllm/blob/master/src/models/chatglm.cpp#L626 ,这里的 input 就是输入的 context(string类型)。然后 https://github.com/ztxz16/fastllm/blob/master/src/models/chatglm.cpp#L633 这行代码对 input 进行 tokenizer encode并构造好inputIds,再构造好attentionMask之后就可以给Forward函数推理,拿到推理结果之后再使用tokenizer进行decode得到输出。
BBuf
2023/08/22
1.3K0
大模型部署框架 FastLLM 实现细节解析
聊聊大模型微调训练全流程的思考
参考现有的中文医疗模型:MedicalGPT、CareGPT等领域模型的训练流程,结合ChatGPT的训练流程,总结如下: 在预训练阶段,模型会从大量无标注文本数据集中学习领域/通用知识;其次使用{有监督微调}(SFT)优化模型以更好地遵守特定指令;最后使用对齐技术使LLM更有用更安全的响应用户的提示。
Ryan_OVO
2024/03/19
1.3K0
聊聊大模型微调训练全流程的思考
大模型部署框架 FastLLM 简要解析
本文主要是对FastLLM做了一个简要介绍,展示了一下FastLLM的部署效果。然后以chatglm-6b为例,对FastLLM模型导出的流程进行了解析,接着解析了chatglm-6b模型部分的核心实现。最后还对FastLLM涉及到的优化技巧进行了简单的介绍。
BBuf
2023/08/22
1K0
大模型部署框架 FastLLM 简要解析
【Pre-Training】Transformers 源码阅读和实践
本文主要针对HuggingFace开源的 transformers,以BERT为例介绍其源码并进行一些实践。主要以pytorch为例 (tf 2.0 代码风格几乎和pytorch一致),介绍BERT使用的Transformer Encoder,Pre-training Tasks和Fine-tuning Tasks。最后,针对预训练好的BERT进行简单的实践,例如产出语句embeddings,预测目标词以及进行抽取式问答。本文主要面向BERT新手,在阅读本文章前,假设读者已经阅读过BERT原论文。
阿泽 Crz
2020/11/25
2.7K0
ChatGLM3 源码解析(一)
ApacheCN_飞龙
2024/03/05
6160
聊聊ChatGLM中P-tuning v2的应用
论文PDF地址:https://arxiv.org/pdf/2110.07602.pdf
Ryan_OVO
2024/01/13
4320
聊聊ChatGLM中P-tuning v2的应用
alphaFold2 | 模型细节之特征提取(三)
文章链接: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8371605/
机器学习炼丹术
2022/11/22
1.2K0
alphaFold2 | 模型细节之特征提取(三)
preprint版本 | 何凯明大神新作MAE | CVPR2022最佳论文候选
本文证明了蒙面自动编码器(MAE)是一种可扩展的计算机视觉自监督学习器。我们的MAE方法很简单:我们屏蔽输入图像的随机补丁并重建丢失的像素。
机器学习炼丹术
2021/12/06
1.3K0
preprint版本 | 何凯明大神新作MAE | CVPR2022最佳论文候选
【LLM系列之PaLM】PaLM: Scaling Language Modeling with Pathways
PaLM 在decoder-only架构中使用标准的 Transformer 模型架构(即每个时间步只能关注其自身和过去的时间步),并进行以下修改: (1)采用SwiGLU激活函数:用于 MLP 中间激活,因为与标准 ReLU、GELU 或 Swish 激活相比,《GLU Variants Improve Transformer》论文里提到:SwiGLU 已被证明可以显著提高模型效果。
致Great
2023/08/25
1K0
【LLM系列之PaLM】PaLM: Scaling Language Modeling with Pathways
相关推荐
十分钟读懂旋转编码(RoPE)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验