diff --git a/src/lib.rs b/src/lib.rs index 3097707..ed2438e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ pub use self::{ bind_group::{BindGroupBuilder, BindGroupLayoutBuilder}, buffer::{BufferExt, BulkBufferBuilder}, context::{Context, ContextBuilder}, - pipeline::RenderPipelineBuilder, + pipeline::{ComputePipelineBuilder, RenderPipelineBuilder}, texture::{Texture, TextureBuilder}, }; diff --git a/src/pipeline.rs b/src/pipeline.rs index bba9446..78eced8 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -106,3 +106,54 @@ impl<'a> RenderPipelineBuilder<'a> { }) } } + +pub struct ComputePipelineBuilder<'a> { + label: &'a str, + shader: &'a wgpu::ShaderModule, + layout_descriptor: Option>, +} + +impl<'a> ComputePipelineBuilder<'a> { + pub fn new(label: &'a str, shader: &'a wgpu::ShaderModule) -> Self { + Self { + label, + shader, + layout_descriptor: None, + } + } + + pub fn with_layout( + mut self, + label: &'a str, + bind_group_layouts: &'a [&wgpu::BindGroupLayout], + push_constant_ranges: &'a [wgpu::PushConstantRange], + ) -> Self { + let layout_descriptor = wgpu::PipelineLayoutDescriptor { + label: Some(label), + bind_group_layouts, + push_constant_ranges, + }; + + self.layout_descriptor = Some(layout_descriptor); + self + } + + pub fn build(self, context: &Context) -> wgpu::ComputePipeline { + let raw_layout; + let layout = if let Some(descriptor) = self.layout_descriptor { + raw_layout = context.device.create_pipeline_layout(&descriptor); + Some(&raw_layout) + } else { + None + }; + + context + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some(self.label), + layout, + module: self.shader, + entry_point: "compute", + }) + } +}