use std::ops::Range;
use bevy::{
core_pipeline::core_3d::graph::{Core3d, Node3d},
ecs::{
query::QueryItem,
system::{lifetimeless::SRes, SystemParamItem},
},
math::FloatOrd,
pbr::{
DrawMesh, MeshInputUniform, MeshPipeline, MeshPipelineKey, MeshPipelineViewLayoutKey,
MeshUniform, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
},
platform::collections::HashSet,
prelude::*,
render::{
batching::{
gpu_preprocessing::{
batch_and_prepare_sorted_render_phase, IndirectParametersCpuMetadata,
UntypedPhaseIndirectParametersBuffers,
},
GetBatchData, GetFullBatchData,
},
camera::ExtractedCamera,
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh},
render_asset::RenderAssets,
render_graph::{
NodeRunError, RenderGraphApp, RenderGraphContext, RenderLabel, ViewNode, ViewNodeRunner,
},
render_phase::{
sort_phase_system, AddRenderCommand, CachedRenderPipelinePhaseItem, DrawFunctionId,
DrawFunctions, PhaseItem, PhaseItemExtraIndex, SetItemPipeline, SortedPhaseItem,
SortedRenderPhasePlugin, ViewSortedRenderPhases,
},
render_resource::{
CachedRenderPipelineId, ColorTargetState, ColorWrites, Face, FragmentState, FrontFace,
MultisampleState, PipelineCache, PolygonMode, PrimitiveState, RenderPassDescriptor,
RenderPipelineDescriptor, SpecializedMeshPipeline, SpecializedMeshPipelineError,
SpecializedMeshPipelines, TextureFormat, VertexState,
},
renderer::RenderContext,
sync_world::MainEntity,
view::{ExtractedView, RenderVisibleEntities, RetainedViewEntity, ViewTarget},
Extract, Render, RenderApp, RenderDebugFlags, RenderSet,
},
};
use nonmax::NonMaxU32;
/// Asset path of the WGSL shader that `StencilPipeline` loads and uses for both its vertex and
/// fragment stages.
const SHADER_ASSET_PATH: &str = "shaders/custom_stencil.wgsl";
/// Entry point: builds an app with the default plugins plus [`MeshStencilPhasePlugin`],
/// registers the scene [`setup`] system on `Startup`, and runs it.
fn main() {
    let mut app = App::new();
    app.add_plugins((DefaultPlugins, MeshStencilPhasePlugin));
    app.add_systems(Startup, setup);
    app.run();
}
/// Startup system that spawns the demo scene: a white ground circle, a cube tagged with
/// [`DrawStencil`], a shadow-casting point light, and a 3D camera with MSAA turned off.
fn setup(
    mut commands: Commands,
    mut meshes: ResMut<Assets<Mesh>>,
    mut materials: ResMut<Assets<StandardMaterial>>,
) {
    // Ground: a radius-4 circle rotated -90 degrees about the X axis.
    let ground_mesh = meshes.add(Circle::new(4.0));
    let ground_material = materials.add(Color::WHITE);
    let ground_transform =
        Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2));
    commands.spawn((
        Mesh3d(ground_mesh),
        MeshMaterial3d(ground_material),
        ground_transform,
    ));

    // Unit cube; the `DrawStencil` marker is what gets it queued into the custom phase.
    let cube_mesh = meshes.add(Cuboid::new(1.0, 1.0, 1.0));
    let cube_material = materials.add(Color::srgb_u8(124, 144, 255));
    commands.spawn((
        Mesh3d(cube_mesh),
        MeshMaterial3d(cube_material),
        Transform::from_xyz(0.0, 0.5, 0.0),
        DrawStencil,
    ));

    // Point light with shadows enabled.
    let light = PointLight {
        shadows_enabled: true,
        ..default()
    };
    commands.spawn((light, Transform::from_xyz(4.0, 8.0, 4.0)));

    // Camera aimed at the origin. MSAA is off; the stencil pipeline is built with the default
    // `MultisampleState`.
    let camera_transform = Transform::from_xyz(-2.0, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y);
    commands.spawn((Camera3d::default(), camera_transform, Msaa::Off));
}
/// Marker component: only visible meshes whose render entity carries it are queued into the
/// custom `Stencil3d` phase by `queue_custom_meshes`.
/// It is extracted to the render world through `ExtractComponentPlugin::<DrawStencil>`.
#[derive(Component, ExtractComponent, Clone, Copy, Default)]
struct DrawStencil;
struct MeshStencilPhasePlugin;
impl Plugin for MeshStencilPhasePlugin {
fn build(&self, app: &mut App) {
app.add_plugins((
ExtractComponentPlugin::<DrawStencil>::default(),
SortedRenderPhasePlugin::<Stencil3d, MeshPipeline>::new(RenderDebugFlags::default()),
));
let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
return;
};
render_app
.init_resource::<SpecializedMeshPipelines<StencilPipeline>>()
.init_resource::<DrawFunctions<Stencil3d>>()
.add_render_command::<Stencil3d, DrawMesh3dStencil>()
.init_resource::<ViewSortedRenderPhases<Stencil3d>>()
.add_systems(ExtractSchedule, extract_camera_phases)
.add_systems(
Render,
(
queue_custom_meshes.in_set(RenderSet::QueueMeshes),
sort_phase_system::<Stencil3d>.in_set(RenderSet::PhaseSort),
batch_and_prepare_sorted_render_phase::<Stencil3d, StencilPipeline>
.in_set(RenderSet::PrepareResources),
),
);
render_app
.add_render_graph_node::<ViewNodeRunner<CustomDrawNode>>(Core3d, CustomDrawPassLabel)
.add_render_graph_edges(Core3d, (Node3d::MainOpaquePass, CustomDrawPassLabel));
}
fn finish(&self, app: &mut App) {
let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
return;
};
render_app.init_resource::<StencilPipeline>();
}
}
/// Resource holding the data `specialize` uses to build the stencil render pipeline.
#[derive(Resource)]
struct StencilPipeline {
/// Mesh pipeline from which the view layout and the `model_only` mesh layout are taken.
mesh_pipeline: MeshPipeline,
/// Handle to the shader loaded from `SHADER_ASSET_PATH`; used for both the vertex and the
/// fragment stage.
shader_handle: Handle<Shader>,
}
impl FromWorld for StencilPipeline {
fn from_world(world: &mut World) -> Self {
Self {
mesh_pipeline: MeshPipeline::from_world(world),
shader_handle: world.resource::<AssetServer>().load(SHADER_ASSET_PATH),
}
}
}
impl SpecializedMeshPipeline for StencilPipeline {
    type Key = MeshPipelineKey;

    /// Builds the render pipeline descriptor for the given pipeline key and mesh vertex layout.
    ///
    /// Only the position attribute is requested from the mesh (at shader location 0). The
    /// pipeline has a single color target, no depth/stencil state and back-face culling.
    ///
    /// # Errors
    /// Propagates the error from `get_layout` when the vertex buffer layout cannot be produced
    /// for the requested attributes.
    fn specialize(
        &self,
        key: Self::Key,
        layout: &MeshVertexBufferLayoutRef,
    ) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
        let mesh_layout = &layout.0;

        // Request the position attribute only when the mesh actually provides it.
        let requested_attributes: Vec<_> = mesh_layout
            .contains(Mesh::ATTRIBUTE_POSITION)
            .then(|| Mesh::ATTRIBUTE_POSITION.at_shader_location(0))
            .into_iter()
            .collect();
        let vertex_buffer_layout = mesh_layout.get_layout(&requested_attributes)?;

        // Index 0: view layout for this key. Index 1: model-only mesh layout.
        let bind_group_layouts = vec![
            self.mesh_pipeline
                .get_view_layout(MeshPipelineViewLayoutKey::from(key))
                .clone(),
            self.mesh_pipeline.mesh_layouts.model_only.clone(),
        ];

        let vertex = VertexState {
            shader: self.shader_handle.clone(),
            shader_defs: Vec::new(),
            entry_point: "vertex".into(),
            buffers: vec![vertex_buffer_layout],
        };

        let color_target = ColorTargetState {
            format: TextureFormat::bevy_default(),
            blend: None,
            write_mask: ColorWrites::ALL,
        };
        let fragment = FragmentState {
            shader: self.shader_handle.clone(),
            shader_defs: Vec::new(),
            entry_point: "fragment".into(),
            targets: vec![Some(color_target)],
        };

        let primitive = PrimitiveState {
            topology: key.primitive_topology(),
            front_face: FrontFace::Ccw,
            cull_mode: Some(Face::Back),
            polygon_mode: PolygonMode::Fill,
            ..default()
        };

        Ok(RenderPipelineDescriptor {
            label: Some("Specialized Mesh Pipeline".into()),
            layout: bind_group_layouts,
            push_constant_ranges: Vec::new(),
            vertex,
            fragment: Some(fragment),
            primitive,
            depth_stencil: None,
            // Not specialized on the key's MSAA sample count; always the default state.
            multisample: MultisampleState::default(),
            zero_initialize_workgroup_memory: false,
        })
    }
}
/// Render command tuple registered for the `Stencil3d` phase: set the item's pipeline, set the
/// mesh view bind group at index 0, set the mesh bind group at index 1, then draw the mesh.
type DrawMesh3dStencil = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMeshBindGroup<1>,
DrawMesh,
);
/// Phase item of the custom sorted render phase; one is added per queued mesh by
/// `queue_custom_meshes`.
struct Stencil3d {
/// Key the phase is sorted by; `queue_custom_meshes` stores the view rangefinder distance of
/// the mesh here.
pub sort_key: FloatOrd,
/// The item's entity paired with its `MainEntity`.
pub entity: (Entity, MainEntity),
/// Id of the specialized pipeline used to draw this item.
pub pipeline: CachedRenderPipelineId,
/// Id of the draw function used to draw this item.
pub draw_function: DrawFunctionId,
/// Batch range of the item; `queue_custom_meshes` initializes it to `0..1`.
pub batch_range: Range<u32>,
/// Extra index of the item; `queue_custom_meshes` initializes it to `PhaseItemExtraIndex::None`.
pub extra_index: PhaseItemExtraIndex,
/// Whether the mesh is indexed, as reported by the render mesh's `indexed()`.
pub indexed: bool,
}
impl PhaseItem for Stencil3d {
    /// Returns the first element of the stored entity pair.
    #[inline]
    fn entity(&self) -> Entity {
        let (render_entity, _) = self.entity;
        render_entity
    }

    /// Returns the second element of the stored entity pair.
    #[inline]
    fn main_entity(&self) -> MainEntity {
        let (_, main_entity) = self.entity;
        main_entity
    }

    /// Returns the id of the draw function stored in the item.
    #[inline]
    fn draw_function(&self) -> DrawFunctionId {
        let Self { draw_function, .. } = self;
        *draw_function
    }

    /// Shared access to the item's batch range.
    #[inline]
    fn batch_range(&self) -> &Range<u32> {
        let Self { batch_range, .. } = self;
        batch_range
    }

    /// Mutable access to the item's batch range.
    #[inline]
    fn batch_range_mut(&mut self) -> &mut Range<u32> {
        let Self { batch_range, .. } = self;
        batch_range
    }

    /// Returns a clone of the item's extra index.
    #[inline]
    fn extra_index(&self) -> PhaseItemExtraIndex {
        let Self { extra_index, .. } = self;
        extra_index.clone()
    }

    /// Mutable access to the batch range and the extra index at the same time.
    #[inline]
    fn batch_range_and_extra_index_mut(&mut self) -> (&mut Range<u32>, &mut PhaseItemExtraIndex) {
        let Self {
            batch_range,
            extra_index,
            ..
        } = self;
        (batch_range, extra_index)
    }
}
impl SortedPhaseItem for Stencil3d {
    type SortKey = FloatOrd;

    /// Returns the key stored in the item.
    #[inline]
    fn sort_key(&self) -> Self::SortKey {
        let Self { sort_key, .. } = self;
        *sort_key
    }

    /// Sorts the items by ascending sort key using the standard library's stable sort.
    #[inline]
    fn sort(items: &mut [Self]) {
        items.sort_by_key(|item| item.sort_key());
    }

    /// Reports whether the item's mesh is indexed.
    #[inline]
    fn indexed(&self) -> bool {
        let Self { indexed, .. } = self;
        *indexed
    }
}
impl CachedRenderPipelinePhaseItem for Stencil3d {
    /// Returns the id of the cached pipeline stored in the item.
    #[inline]
    fn cached_pipeline(&self) -> CachedRenderPipelineId {
        let Self { pipeline, .. } = self;
        *pipeline
    }
}
impl GetBatchData for StencilPipeline {
    type Param = (
        SRes<RenderMeshInstances>,
        SRes<RenderAssets<RenderMesh>>,
        SRes<MeshAllocator>,
    );
    type CompareData = AssetId<Mesh>;
    type BufferData = MeshUniform;

    /// Builds the [`MeshUniform`] for `main_entity` from the CPU-built mesh instance data.
    ///
    /// Returns `None` (after logging an error) when the mesh instances are not in CPU building
    /// mode, and `None` when the entity has no mesh instance. The compare data of a successful
    /// result is always `None`.
    fn get_batch_data(
        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
        (_entity, main_entity): (Entity, MainEntity),
    ) -> Option<(Self::BufferData, Option<Self::CompareData>)> {
        let cpu_instances = match **mesh_instances {
            RenderMeshInstances::CpuBuilding(ref instances) => instances,
            _ => {
                error!(
                    "`get_batch_data` should never be called in GPU mesh uniform \
                    building mode"
                );
                return None;
            }
        };
        let instance = cpu_instances.get(&main_entity)?;

        // Start of the mesh's vertex slice, or 0 when the allocator has no slice for it.
        let first_vertex_index = mesh_allocator
            .mesh_vertex_slice(&instance.mesh_asset_id)
            .map_or(0, |vertex_slice| vertex_slice.range.start);

        let transforms = &instance.transforms;
        let (local_from_world_transpose_a, local_from_world_transpose_b) =
            transforms.world_from_local.inverse_transpose_3x3();
        let world_from_local = transforms.world_from_local.to_transpose();
        let previous_world_from_local = transforms.previous_world_from_local.to_transpose();

        // The lightmap rect, skin index, material slot and tag are fixed values here, not
        // read from the instance.
        let uniform = MeshUniform {
            world_from_local,
            previous_world_from_local,
            lightmap_uv_rect: UVec2::ZERO,
            local_from_world_transpose_a,
            local_from_world_transpose_b,
            flags: transforms.flags,
            first_vertex_index,
            current_skin_index: u32::MAX,
            material_and_lightmap_bind_group_slot: 0,
            tag: 0,
            pad: 0,
        };
        Some((uniform, None))
    }
}
impl GetFullBatchData for StencilPipeline {
    type BufferInputData = MeshInputUniform;

    /// Returns the current uniform index of `main_entity`'s mesh instance, plus its mesh asset
    /// id as compare data when the instance reports `should_batch()`.
    ///
    /// Returns `None` (after logging an error) when the mesh instances are not in GPU building
    /// mode, and `None` when the entity has no mesh instance.
    fn get_index_and_compare_data(
        (mesh_instances, _, _): &SystemParamItem<Self::Param>,
        main_entity: MainEntity,
    ) -> Option<(NonMaxU32, Option<Self::CompareData>)> {
        let gpu_instances = match **mesh_instances {
            RenderMeshInstances::GpuBuilding(ref instances) => instances,
            _ => {
                error!(
                    "`get_index_and_compare_data` should never be called in CPU mesh uniform building \
                    mode"
                );
                return None;
            }
        };
        let instance = gpu_instances.get(&main_entity)?;
        let compare_data = instance.should_batch().then_some(instance.mesh_asset_id);
        Some((instance.current_uniform_index, compare_data))
    }

    /// Builds the [`MeshUniform`] for `main_entity` through `MeshUniform::new`, using the
    /// instance's transforms, first vertex index and material binding slot.
    ///
    /// Returns `None` (after logging an error) when the mesh instances are not in CPU building
    /// mode, and `None` when the entity has no mesh instance.
    fn get_binned_batch_data(
        (mesh_instances, _render_assets, mesh_allocator): &SystemParamItem<Self::Param>,
        main_entity: MainEntity,
    ) -> Option<Self::BufferData> {
        let cpu_instances = match **mesh_instances {
            RenderMeshInstances::CpuBuilding(ref instances) => instances,
            _ => {
                error!(
                    "`get_binned_batch_data` should never be called in GPU mesh uniform building mode"
                );
                return None;
            }
        };
        let instance = cpu_instances.get(&main_entity)?;

        // Start of the mesh's vertex slice, or 0 when the allocator has no slice for it.
        let first_vertex_index = mesh_allocator
            .mesh_vertex_slice(&instance.mesh_asset_id)
            .map_or(0, |vertex_slice| vertex_slice.range.start);

        // The three trailing optional arguments are left as `None`.
        let uniform = MeshUniform::new(
            &instance.transforms,
            first_vertex_index,
            instance.material_bindings_index.slot,
            None,
            None,
            None,
        );
        Some(uniform)
    }

    /// Writes the CPU-side indirect parameters metadata at `indirect_parameters_offset` into
    /// the indexed or the non-indexed buffer, chosen by `indexed`.
    ///
    /// A missing `batch_set_index` is stored as `!0` (all bits set).
    fn write_batch_indirect_parameters_metadata(
        indexed: bool,
        base_output_index: u32,
        batch_set_index: Option<NonMaxU32>,
        indirect_parameters_buffers: &mut UntypedPhaseIndirectParametersBuffers,
        indirect_parameters_offset: u32,
    ) {
        let metadata = IndirectParametersCpuMetadata {
            base_output_index,
            batch_set_index: batch_set_index.map_or(!0, u32::from),
        };
        match indexed {
            true => indirect_parameters_buffers
                .indexed
                .set(indirect_parameters_offset, metadata),
            false => indirect_parameters_buffers
                .non_indexed
                .set(indirect_parameters_offset, metadata),
        }
    }

    /// Always returns `None`.
    fn get_binned_index(
        _param: &SystemParamItem<Self::Param>,
        _query_item: MainEntity,
    ) -> Option<NonMaxU32> {
        None
    }
}
/// Extraction system that keeps one `Stencil3d` phase per active 3D camera.
///
/// Each active camera gets its phase inserted or cleared under a [`RetainedViewEntity`] built
/// with subview index 0. Phases whose view was not seen this run are removed.
fn extract_camera_phases(
    mut stencil_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
    cameras: Extract<Query<(Entity, &Camera), With<Camera3d>>>,
    mut live_entities: Local<HashSet<RetainedViewEntity>>,
) {
    // The set is reused between runs, so start from empty.
    live_entities.clear();

    for (camera_entity, camera) in &cameras {
        if camera.is_active {
            let view = RetainedViewEntity::new(camera_entity.into(), None, 0);
            stencil_phases.insert_or_clear(view);
            live_entities.insert(view);
        }
    }

    // Drop the phases of views that are no longer live.
    stencil_phases.retain(|view, _| live_entities.contains(view));
}
fn queue_custom_meshes(
custom_draw_functions: Res<DrawFunctions<Stencil3d>>,
mut pipelines: ResMut<SpecializedMeshPipelines<StencilPipeline>>,
pipeline_cache: Res<PipelineCache>,
custom_draw_pipeline: Res<StencilPipeline>,
render_meshes: Res<RenderAssets<RenderMesh>>,
render_mesh_instances: Res<RenderMeshInstances>,
mut custom_render_phases: ResMut<ViewSortedRenderPhases<Stencil3d>>,
mut views: Query<(&ExtractedView, &RenderVisibleEntities, &Msaa)>,
has_marker: Query<(), With<DrawStencil>>,
) {
for (view, visible_entities, msaa) in &mut views {
let Some(custom_phase) = custom_render_phases.get_mut(&view.retained_view_entity) else {
continue;
};
let draw_custom = custom_draw_functions.read().id::<DrawMesh3dStencil>();
let view_key = MeshPipelineKey::from_msaa_samples(msaa.samples())
| MeshPipelineKey::from_hdr(view.hdr);
let rangefinder = view.rangefinder3d();
for (render_entity, visible_entity) in visible_entities.iter::<Mesh3d>() {
if has_marker.get(*render_entity).is_err() {
continue;
}
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*visible_entity)
else {
continue;
};
let Some(mesh) = render_meshes.get(mesh_instance.mesh_asset_id) else {
continue;
};
let mut mesh_key = view_key;
mesh_key |= MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
let pipeline_id = pipelines.specialize(
&pipeline_cache,
&custom_draw_pipeline,
mesh_key,
&mesh.layout,
);
let pipeline_id = match pipeline_id {
Ok(id) => id,
Err(err) => {
error!("{}", err);
continue;
}
};
let distance = rangefinder.distance_translation(&mesh_instance.translation);
custom_phase.add(Stencil3d {
sort_key: FloatOrd(distance),
entity: (*render_entity, *visible_entity),
pipeline: pipeline_id,
draw_function: draw_custom,
batch_range: 0..1,
extra_index: PhaseItemExtraIndex::None,
indexed: mesh.indexed(),
});
}
}
}
/// Render graph label under which the custom stencil draw node is registered in the `Core3d`
/// graph.
#[derive(RenderLabel, Debug, Clone, Hash, PartialEq, Eq)]
struct CustomDrawPassLabel;
/// Render graph view node that renders the `Stencil3d` phase of a view; see its `ViewNode` impl.
#[derive(Default)]
struct CustomDrawNode;
impl ViewNode for CustomDrawNode {
    type ViewQuery = (
        &'static ExtractedCamera,
        &'static ExtractedView,
        &'static ViewTarget,
    );

    /// Renders the view's `Stencil3d` phase into the view target's color attachment, with no
    /// depth/stencil attachment.
    ///
    /// Does nothing when the phases resource is absent or holds no phase for this view. A
    /// render error is logged rather than propagated, so this always returns `Ok(())`.
    fn run<'w>(
        &self,
        graph: &mut RenderGraphContext,
        render_context: &mut RenderContext<'w>,
        (camera, view, target): QueryItem<'w, Self::ViewQuery>,
        world: &'w World,
    ) -> Result<(), NodeRunError> {
        // Look up this view's phase; bail out quietly if either lookup misses.
        let phase = match world
            .get_resource::<ViewSortedRenderPhases<Stencil3d>>()
            .and_then(|phases| phases.get(&view.retained_view_entity))
        {
            Some(phase) => phase,
            None => return Ok(()),
        };

        let color_attachments = [Some(target.get_color_attachment())];
        let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor {
            label: Some("stencil pass"),
            color_attachments: &color_attachments,
            depth_stencil_attachment: None,
            timestamp_writes: None,
            occlusion_query_set: None,
        });

        // Restrict drawing to the camera's viewport when it has one.
        if let Some(viewport) = &camera.viewport {
            render_pass.set_camera_viewport(viewport);
        }

        if let Err(err) = phase.render(&mut render_pass, world, graph.view_entity()) {
            error!("Error encountered while rendering the stencil phase {err:?}");
        }
        Ok(())
    }
}