Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >SwiftUI-MLX本地大模型开发(四)

SwiftUI-MLX本地大模型开发(四)

作者头像
YungFan
发布于 2025-04-21 01:27:50
发布于 2025-04-21 01:27:50
10900
代码可运行
举报
文章被收录于专栏:学海无涯学海无涯
运行总次数:0
代码可运行

介绍

SwiftUI-MLX本地大模型开发SwiftUI-MLX本地大模型开发(二)SwiftUI-MLX本地大模型开发(三)中,我们解决了基本使用、定制模型、使用本地模型、更改模型存储路径、转换模型、iPad运行等问题,但使用的都是别人训练好的模型。本文将介绍,如何基于一个通用 LLM 进行微调,使该模型成为个人的“专属”模型。

环境

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
pip install mlx
pip install mlx-lm
pip install transformers

数据

  • 新建文件train.jsonlvalid.jsonltest.jsonl,分别用于训练、验证与测试。
  • 根据模型说明文件,准备数据,格式如下。本文以ticoAg/Chinese-medical-dialogue进行微调。
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
{"text": "你是谁?\n我是你的私人智能小助手,我叫羊羊。"}
{"text": "地球有多大?\n地球的半径大约是6371公里。"}
...

微调

  • 微调时可以指定--model--data--adapter-path等参数。
  • 成功之后,会在adapters目录下生成多个.safetensors文件。
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mlx_lm.lora \
 --train \
 --model /Users/yangfan/Documents/huggingface/models/mlx-community/Llama-3.2-1B-Instruct-4bit \
 --adapter-path /Users/yangfan/Desktop/adapters \
 --data /Users/yangfan/Desktop/Data \
 --batch-size 1 \

合并

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mlx_lm.fuse \
 --model /Users/yangfan/Documents/huggingface/models/mlx-community/Llama-3.2-1B-Instruct-4bit \
 --adapter-path  /Users/yangfan/Desktop/adapters \
 --save-path /Users/yangfan/Desktop/Llama-3.2-1B-Instruct-4bit-fused # 新模型目录

代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import MLXLLM
import MLXLMCommon
import SwiftUI

// MARK: - 注册自定义模型
extension MLXLLM.ModelRegistry {
    public static let localModel = ModelConfiguration(
        directory: URL(fileURLWithPath: "/Users/yangfan/Desktop/Llama-3.2-1B-Instruct-4bit-fused"),
        overrideTokenizer: "PreTrainedTokenizer",
        defaultPrompt: ""
    )
}


struct ContentView: View {
    // 提示词
    @State private var prompt: String = "小孩扁桃体炎总哭饮食注意点是什么?"
    // 输出结果
    @State private var response: String = ""
    @State private var isLoading: Bool = false

    var body: some View {
        VStack(spacing: 16) {
            // 顶部输入区域
            HStack {
                TextField("输入提示词...", text: $prompt)
                    .textFieldStyle(.roundedBorder)
                    .font(.system(size: 16))

                Button {
                    response = ""

                    Task {
                        do {
                            try await generate()
                        } catch {
                            debugPrint(error)
                        }
                    }
                } label: {
                    Text("生成")
                        .foregroundStyle(.white)
                        .padding(.horizontal, 16)
                        .padding(.vertical, 8)
                        .background(prompt.isEmpty ? Color.gray : Color.blue)
                        .cornerRadius(8)
                }
                .buttonStyle(.borderless)
                .disabled(prompt.isEmpty || isLoading)
            }
            .padding(.horizontal)
            .padding(.top)

            // 分隔线
            Rectangle()
                .fill(Color.gray.opacity(0.2))
                .frame(height: 1)

            // 响应展示区域
            if response != "" {
                ResponseBubble(text: response)
            }

            Spacer()
        }

        if isLoading {
            ProgressView()
                .progressViewStyle(.circular)
                .padding()
        }
    }
}

extension ContentView {
    // MARK: 文本生成
    func generate() async throws {
        isLoading = true
        // 加载模型
        let modelConfiguration = ModelRegistry.localModel
        let modelContainer = try await LLMModelFactory.shared.loadContainer(configuration: modelConfiguration) { progress in
            print("正在下载 \(modelConfiguration.name),当前进度 \(Int(progress.fractionCompleted * 100))%")
        }
        // 生成结果
        let _ = try await modelContainer.perform { [prompt] context in
            let input = try await context.processor.prepare(input: .init(prompt: prompt))
            let result = try MLXLMCommon.generate(input: input, parameters: .init(), context: context) { tokens in
                let text = context.tokenizer.decode(tokens: tokens)
                Task { @MainActor in
                    self.response = text
                    self.isLoading = false
                }
                return .more
            }
            return result
        }
    }
}

struct ResponseBubble: View {
    let text: String

    var body: some View {
        ScrollView {
            VStack(alignment: .leading, spacing: 8) {
                Text("AI")
                    .font(.system(size: 16))
                    .foregroundColor(.gray)

                Text(text)
                    .font(.system(size: 16))
                    .lineSpacing(4)
                    .padding()
                    .background(Color.blue.opacity(0.1))
                    .cornerRadius(12)
            }
        }
        .padding(.horizontal)
    }
}

效果

  • 原始模型效果。

原始模型效果.gif

  • 微调模型效果。

微调模型.gif

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2025-04-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验