
新手也能学会的高性能gpu开发,只需要rust和cubel基础知识即可实现,让你的程序简单的使用gpu加速!
CubeCL是一个现代化的Rust GPU计算框架,它让编写高性能、可移植的GPU内核变得简单。通过CubeCL,你可以:
提示:即使没有GPU编程经验,通过本教程你也能快速上手CubeCL!
配置CubeCL非常简单,只需在Cargo.toml中添加依赖:
[dependencies]
cubecl = { version = "0.4.0", features = ["wgpu","default","std"] }wgpu:使用WGPU后端(跨平台)std:启用标准库支持小贴士:开发时建议同时启用wgpu和cuda特性,这样可以灵活切换后端测试
[dependencies]
cubecl = { version = "0.4.0",features = ["wgpu","cuda","default","std"] }让我们从一个简单的GPU计算程序开始。虽然初看可能有些复杂,但我们会逐步解析每个部分。
use cubecl::prelude::*;
#[cube(launch)] // 标记为可启动的GPU内核入口函数
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
// ABSOLUTE_POS_X是自动生成的线程索引
if ABSOLUTE_POS_X < input.len() { // 边界检查
// 对每个元素应用gelu_scalar激活函数
output[ABSOLUTE_POS_X] = gelu_scalar(input[ABSOLUTE_POS_X]);
}
}
#[cube] // 标记为GPU函数
fn gelu_scalar<F: Float>(x: Line<F>) -> Line<F> {
add(x, x/2.0)
}
#[cube] // 标记为GPU函数
fn add<F:Float>(a:Line<F>, b:Line<F>) -> Line<F> {
a + b // 向量化加法
}
pub fn lanch_test<R: Runtime>(device: &R::Device) {
// 创建与GPU设备的连接
let client = R::client(device);
// 准备测试数据: 4096个5.0f32
let input = &[5f32;8];
let vectorization = 4; // 向量化宽度
// 分配GPU内存
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
let input_handle = client.create(f32::as_bytes(input));
unsafe {
// 启动GPU内核
gelu_array::launch::<f32, R>(
&client, // GPU客户端
CubeCount::Static(1, 1, 1), // 1个Hyper-Cube
CubeDim::new(
2, // X维度
1, // Y维度
1 // Z维度
),
ArrayArg::from_raw_parts::<f32>(&input_handle, input.len(), vectorization as u8),
ArrayArg::from_raw_parts::<f32>(&output_handle, input.len(), vectorization as u8),
)
};
// 读取结果
let bytes = client.read_one(output_handle.binding());
let output = f32::from_bytes(&bytes);
println!("GPU计算结果(Runtime: {:?}) => {:?}", R::name(), output);
}
fn main() {
type Runtime = cubecl::wgpu::WgpuRuntime;
let device = Default::default();
launch_test::<Runtime>(&device);
}虽然这段代码初看有些复杂,但CubeCL的设计实际上隐藏了许多GPU编程的复杂性。让我们分解理解每个部分:
Runtime是CubeCL的核心概念之一,它决定了你的代码将在哪种GPU后端上运行:
type Runtime = cubecl::wgpu::WgpuRuntime;在cubecl中,runtime代表了我们的gpu运算将基于什么去运行,这里我选择的是wgpu,同理,我们也可以将他换成为cuda。
Device代表实际的运算硬件。现代计算机可能有:
CubeCL支持灵活选择设备:
let device = WgpuDevice::Cpu;let device = WgpuDevice::DiscreteGpu(0); // 独立gpu 参数为显卡在系统中的序号let device = WgpuDevice::IntegratedGpu(0); // 集成gpu 参数为显卡在系统中的序号let device = WgpuDevice::VirtualGpu(0); // 虚拟gpu 参数为显卡在系统中的序号Client是连接CPU和GPU的桥梁,主要功能包括:
let client = R::client(device);client 是 GPU 运行时(如 CUDA 或 OpenCL)的高层抽象,封装了以下功能:
client.create() 和 client.empty())。client.read_one())。其中
R::client(device) 创建与指定GPU设备(device)绑定的运行时客户端。
client.create(data) 将CPU数据(data)拷贝到GPU显存,返回显存句柄(input_handle)。
client.empty(size) 在GPU显存中分配未初始化的空间(大小为size字节),返回句柄。
client.read_one(handle) 将GPU显存中的数据(通过handle标识)读回CPU内存。
gelu_array::launch(client, ...) 通过client提交内核执行任务到GPU队列。可以说client 是GPU编程的入口,它负责连接设备、管理数据、执行任务。
与传统GPU编程不同,CubeCL允许直接用Rust编写运算逻辑。关键点是#[cube]宏:
#[cube]
fn add<F:Float>(a:Line<F>, b:Line<F>) -> Line<F> {
a + b // 这个加法运算将在GPU上执行!
}Line<T>类型表示可向量化数据#[cube(launch)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
if ABSOLUTE_POS< input.len() {
output[ABSOLUTE_POS] = gelu_scalar(input[ABSOLUTE_POS_X]);
}
}
#[cube]
fn gelu_scalar<F: Float>(x: Line<F>) -> Line<F> {
// Execute the sqrt function at comptime.
add(x,x/2.0)
}
#[cube]
fn add<F:Float>(a:Line<F>, b:Line<F>) -> Line<F> {
a+b
}#[cube]宏详解#[cube]宏标记的函数将在GPU上执行,支持多种变体:
宏变体 | 用途 |
|---|---|
#[cube] | 基本GPU函数 |
#[cube(launch)] | 生成可启动的内核入口函数 |
#[cube(debug)] | 调试模式,打印生成代码 |
在这个“入口”函数中,我们是如下定义的
#[cube(launch)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) 官方是这么描述line的,我们的数据就从这里进入。
/// A [Line] represents a contiguous series of elements where SIMD operations may be available.
/// The runtime will automatically use SIMD instructions when possible for improved performance.正常情况下,使用launch去创建一个“入口”,之后我们就可以调用 函数名::launch了,如下
gelu_array::launch::<f32, R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(2u32, 1u32, 1u32),
ArrayArg::from_raw_parts::<f32>(&input_handle, input.len(), vectorization as u8),
ArrayArg::from_raw_parts::<f32>(&output_handle, input.len(), vectorization as u8),
);我们发现这里比我们设定的参数多了好多东西,最开始所传入的就是我们所生成的client,也就是调用gpu的那个接口,接下来的东西比较多,我们需要慢慢解释。
CubeCL使用独特的"多维立方体"模型管理GPU并行:
CubeCount::Static定义数量CubeDim::new定义尺寸这种抽象让并行计算更直观,Hyper-Cube又由Cube构成:

Cube由unit构成

最小单元就是unit每个unit中可以存放的数据量就是vectorization,当然这个也可以理解为一个uint的线程数。

gelu_array::launch::<f32, R>(
&client,
CubeCount::Static(1, 1, 1), // 1个Hyper-Cube
CubeDim::new(2, 1, 1), // 每个Cube是2x1x1
input_arg, // 输入数据
output_arg, // 输出缓冲区
);这个配置表示:
合理配置可以最大化GPU利用率:
至于为什么要计算这个,我们还要先回到运算逻辑代码处,我们先说明这段代码的作用
#[cube(launch)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
if ABSOLUTE_POS< input.len() {
output[ABSOLUTE_POS] = gelu_scalar(input[ABSOLUTE_POS]);
}
}这里虽然并没有出现for循环,但我们可以理解再这个函数内部是不断循环的,其中ABSOLUTE_POS就是不断遍历的下标,去执行我们定义的函数。同理下面这个是只遍历x轴上的unit,将每个x轴上的unit进行gelu_scalar函数处理,进行单列运算
#[cube(launch)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
if ABSOLUTE_POS_X < input.len() {
output[ABSOLUTE_POS_X] = gelu_scalar(input[ABSOLUTE_POS_X]);
}
}循环展开(Loop Unrolling) 是一种通过减少循环控制开销(如分支判断、计数器更新)来提升性能的优化技术。它通过将循环体内的代码重复多次,减少循环迭代次数,从而提高指令级并行性(ILP)和内存访问效率。以下是 CUDA 循环展开的详细解释和实现方法:
#[cube(launch_unchecked)]
fn sum_basic<F: Float>(input: &Array<F>, output: &mut Array<F>, #[comptime] end: Option<u32>) {
let unroll = end.is_some();
let end = end.unwrap_or_else(|| input.len());
let mut sum = F::new(0.0);
#[unroll(unroll)]
for i in 0..end {
sum += input[i];
}
output[UNIT_POS] = sum;
}处理矩阵需要同时考虑X和Y维度:
#[cube(launch)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
if ABSOLUTE_POS_X < input.len() {
if ABSOLUTE_POS_Y < output.len() {
output[ABSOLUTE_POS_X+ABSOLUTE_POS_Y *【x每行unit的个数】] = gelu_scalar(input[ABSOLUTE_POS_X+ABSOLUTE_POS_Y *【x每行unit的个数】]);
}
}
}关键点:
ABSOLUTE_POS_X和ABSOLUTE_POS_Y访问二维索引当然,仅仅是处理一个参数在世界情况中还是并不常见,那么多个参数的情况也很简单,只需要在队伍的lauch中添加参数即可
#[cube(launch)]
fn gelu_array<F: Float>(input1: &Array<Line<F>>,input2:&Array<Line<F>>,output: &mut Array<Line<F>>) {
if ABSOLUTE_POS_X < input1.len() {
if ABSOLUTE_POS_Y < output.len() {
output[ABSOLUTE_POS_X+ABSOLUTE_POS_Y*2] = add(input1[ABSOLUTE_POS_X+ABSOLUTE_POS_Y*2],input2[ABSOLUTE_POS_X+ABSOLUTE_POS_Y*2]);
}
}
}gelu_array::launch::<f32, R>(
&client,
CubeCount::Static(2, 1, 1),
CubeDim::new(2u32, 2u32, 1u32),
ArrayArg::from_raw_parts::<f32>(&input_handle, input.len(), vectorization as u8),
ArrayArg::from_raw_parts::<f32>(&input_handle, input.len(), vectorization as u8),
ArrayArg::from_raw_parts::<f32>(&output_handle, input.len(), vectorization as u8),
)cpu处理
use image::{ImageBuffer, Rgb, RgbImage};
fn main() {
// 1. 读取图像
let img = image::open("./../image.jpg").unwrap();
// 2. 转换为 RGB 格式(确保处理颜色通道)
let rgb_img = img.to_rgb8();
// 3. 创建可变的 ImageBuffer
let mut buffer: RgbImage = ImageBuffer::new(rgb_img.width(), rgb_img.height());
// 4. 遍历像素并反转颜色
for (x, y, pixel) in rgb_img.enumerate_pixels() {
let inverted = Rgb([
255 - pixel[0], // 反转红色通道
255 - pixel[1], // 反转绿色通道
255 - pixel[2], // 反转蓝色通道
]);
buffer.put_pixel(x, y, inverted);
}
// 5. 保存处理后的图像
buffer.save("./../output.jpg").unwrap();
}gpu处理,这里可以看到cubecl使用的并不只有泛型,还可以使用u32等类型去运算,但必需保证运算的类型是可以被支持的,否则会出现以下错误
U8 is not a valid WgpuElement实现
extern crate core;
use core::u32;
use cubecl::prelude::Float;
use std::ops::Sub;
use cubecl::prelude::*;
use cubecl::wgpu::WgpuDevice;
use image::{ImageBuffer, Rgb};
#[cube(launch)]
fn gelu_array(input1: &Array<Line<u32>>,output: &mut Array<Line<u32>>) {
if ABSOLUTE_POS < input1.len() {
output[ABSOLUTE_POS] = gelu_scalar(input1[ABSOLUTE_POS])
}
}
#[cube] // 标记为GPU函数
fn gelu_scalar(x: Line<u32>) -> Line<u32> {
minus(x)
}
#[cube]
fn minus(a:Line<u32>) -> Line<u32> {
Line::new(u32::max_value())-a
}
pub fn launch<R: Runtime>(device: &R::Device,list:&[u8],w:u32,h:u32) {
let client = R::client(device);
let vectorization = 4;
let output_handle = client.empty(list.len() * core::mem::size_of::<u8>());
let input_handle = client.create(list);
unsafe {
gelu_array::launch::<R>(
&client,
CubeCount::Static(8000, 1, 1),
CubeDim::new(100u32, 10u32, 1u32),
ArrayArg::from_raw_parts::<u32>(&input_handle, list.len(), vectorization as u8),
ArrayArg::from_raw_parts::<u32>(&output_handle, list.len(), vectorization as u8),
)
};
let bytes = client.read_one(output_handle.binding());
let bytes2 = client.read_one(input_handle.binding());
// 把RGB Vec<u8>转化为图片
let b: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_raw(w,h,bytes).unwrap();
b.save("./../output.png").unwrap();
}
fn main() {
type Runtime = cubecl::wgpu::WgpuRuntime;
let device = WgpuDevice::default();
let img = image::open("./../image.jpg").unwrap();
// 转换为 RGB Vec<u8>格式(确保处理颜色通道)
let rgb_img = img.to_rgb8();
let (w,h) = rgb_img.dimensions();
let buf = rgb_img.into_raw();
launch::<Runtime>(&device,&buf,w,h);
}原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。