-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontiguous.rs
More file actions
132 lines (115 loc) · 4.5 KB
/
contiguous.rs
File metadata and controls
132 lines (115 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use crate::shapes::{MatrixOrdering, ViewShape, ViewShapeBuffers};
use crate::tensor::GpuTensorView;
use slang_hal::backend::Backend;
use slang_hal::function::GpuFunction;
use slang_hal::{Shader, ShaderArgs};
#[derive(Shader)]
#[shader(module = "stensor::linalg::contiguous")]
/// Kernels from the `stensor::linalg::contiguous` Slang module, used to copy a
/// non-contiguous tensor into a contiguous one.
///
/// [`Contiguous::launch`] picks one of the two functions below depending on the
/// ordering (row-major or column-major) of the destination tensor.
pub struct Contiguous<B: Backend> {
/// Kernel copying a non-contiguous tensor into a row-major contiguous tensor.
pub contiguous_row_major: GpuFunction<B>,
/// Kernel copying a non-contiguous tensor into a column-major contiguous tensor.
pub contiguous_col_major: GpuFunction<B>,
}
/// Arguments handed to the kernels of [`Contiguous`] when they are launched.
// NOTE(review): the field names presumably have to match the parameter names declared in
// the Slang module — confirm against the shader source before renaming any of them.
#[derive(ShaderArgs)]
struct ContiguousArgs<'a, B: Backend> {
/// Buffer slice of the source tensor (the one whose layout may be non-contiguous).
tensor: B::BufferSlice<'a, f32>,
/// Buffer slice of the contiguous destination tensor.
out: B::BufferSlice<'a, f32>,
/// GPU buffer holding the `ViewShape` of the source tensor.
shape: &'a B::Buffer<ViewShape>,
}
impl<B: Backend> Contiguous<B> {
    /// Copies the elements of `tensor`, whose layout may be non-contiguous, into the
    /// contiguous tensor `out`.
    ///
    /// The kernel is selected from the ordering reported by `out`: a column-major output
    /// runs `contiguous_col_major`, a row-major output runs `contiguous_row_major`. The
    /// shape of `tensor` is registered in `shapes` so that its GPU buffer can be bound as
    /// a kernel argument.
    ///
    /// # Errors
    ///
    /// Returns the backend error produced while inserting the shape into `shapes` or while
    /// launching the kernel.
    ///
    /// # Panics
    ///
    /// Panics if `tensor` and `out` do not report the same `size`, or if `out` is not
    /// contiguous.
    pub fn launch<'a>(
        &self,
        backend: &B,
        shapes: &mut ViewShapeBuffers<B>,
        pass: &mut B::Pass,
        out: impl Into<GpuTensorView<'a, f32, B>>,
        tensor: impl Into<GpuTensorView<'a, f32, B>>,
    ) -> Result<(), B::Error> {
        let dst = out.into();
        let src = tensor.into();
        let src_shape = src.shape();

        // Source and destination views must report the same `size`.
        assert_eq!(src_shape.size, dst.shape().size);

        // The destination ordering decides which kernel performs the copy.
        let kernel = match dst.is_contiguous() {
            Some(MatrixOrdering::RowMajor) => &self.contiguous_row_major,
            Some(MatrixOrdering::ColumnMajor) => &self.contiguous_col_major,
            None => panic!("Output tensor must be contiguous."),
        };

        // Register the source shape, then fetch the buffer that was just inserted.
        shapes.insert(backend, src_shape)?;
        let Some(shape) = shapes.get(src_shape) else {
            unreachable!()
        };

        let args = ContiguousArgs {
            tensor: src.buffer(),
            out: dst.buffer(),
            shape,
        };
        kernel.launch_capped(backend, pass, &args, src_shape.len() as u32)
    }
}
#[cfg(test)]
mod test {
    use crate::shapes::ViewShapeBuffers;
    use crate::tensor::GpuTensor;
    use minislang::SlangCompiler;
    use nalgebra::DMatrix;
    use slang_hal::backend::WebGpu;
    use slang_hal::backend::{Backend, Encoder};
    use slang_hal::{BufferUsages, Shader};
    use wgpu::{Features, Limits};

    /// Runs the contiguous-copy check on the CUDA backend (only built with the `cuda`
    /// feature).
    #[futures_test::test]
    #[serial_test::serial]
    #[cfg(feature = "cuda")]
    async fn gpu_contiguous_cuda() {
        // The backend is moved into the generic test and never mutated here.
        let backend = slang_hal::cuda::Cuda::new().unwrap();
        gpu_contiguous_generic(backend).await;
    }

    /// Runs the contiguous-copy check on the WebGPU backend with default features and
    /// limits.
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_contiguous_webgpu() {
        let backend = WebGpu::new(Features::default(), Limits::default())
            .await
            .unwrap();
        gpu_contiguous_generic(backend).await;
    }

    /// Backend-agnostic check: uploads a random `NROWS x NCOLS` matrix, copies its
    /// transposed view into a `NCOLS x NROWS` GPU tensor with [`super::Contiguous`], reads
    /// the result back and compares it with the transpose computed on the CPU.
    async fn gpu_contiguous_generic(backend: impl Backend) {
        let mut compiler = SlangCompiler::new(vec![]);
        crate::register_shaders(&mut compiler);
        let contiguous = super::Contiguous::from_backend(&backend, &compiler).unwrap();
        let mut shapes = ViewShapeBuffers::new(&backend);

        const NROWS: u32 = 256;
        const NCOLS: u32 = 128;

        let tensor = DMatrix::<f32>::new_random(NROWS as usize, NCOLS as usize);
        // Random initial content, so a kernel that writes nothing cannot pass by accident.
        let mut output = DMatrix::<f32>::new_random(NCOLS as usize, NROWS as usize);

        let gpu_tensor = GpuTensor::matrix(&backend, &tensor, BufferUsages::STORAGE).unwrap();
        let gpu_output = GpuTensor::matrix(
            &backend,
            &output,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        )
        .unwrap();

        let mut encoder = backend.begin_encoding();
        let mut pass = encoder.begin_pass();
        contiguous
            .launch(
                &backend,
                &mut shapes,
                &mut pass,
                &gpu_output,
                gpu_tensor.as_view().transposed(),
            )
            .unwrap();
        drop(pass); // Ensure the pass is ended before the encoder is borrowed again.

        backend.submit(encoder).unwrap();
        backend.synchronize().unwrap();
        backend
            .slow_read_buffer(gpu_output.buffer(), output.as_mut_slice())
            .await
            .unwrap();

        // NOTE: don't use assert_relative_eq so it doesn't print out the whole matrices
        //       when it fails (it tends to break rustrover tests integration).
        assert!(
            output == tensor.transpose(),
            "GPU output differs from the CPU transpose of the input"
        );
    }
}