Shaders / Material - WESL

Back to examples View in GitHub

A shader that uses the WESL shading language.

use bevy::{
    mesh::MeshVertexBufferLayoutRef,
    pbr::{MaterialPipeline, MaterialPipelineKey},
    prelude::*,
    reflect::TypePath,
    render::render_resource::{
        AsBindGroup, RenderPipelineDescriptor, SpecializedMeshPipelineError,
    },
    shader::{ShaderDefVal, ShaderRef},
};

/// Path of the WESL fragment shader used by [`CustomMaterial`], relative to the
/// example's `assets` directory (this example loads its shader sources from there).
const FRAGMENT_SHADER_ASSET_PATH: &str = "shaders/custom_material.wesl";

/// Entry point: registers the default plugins, the custom material, and the plugin
/// that loads the utility shader, then schedules the scene setup and per-frame update.
fn main() {
    let mut app = App::new();

    app.add_plugins((
        DefaultPlugins,
        MaterialPlugin::<CustomMaterial>::default(),
        CustomMaterialPlugin,
    ));
    app.add_systems(Startup, setup);
    app.add_systems(Update, update);

    app.run();
}

/// A plugin that loads the utility shader (`shaders/util.wesl`), which the custom
/// material's WESL shader imports from, and stores its handle in the
/// [`UtilityShader`] resource.
pub struct CustomMaterialPlugin;

/// Holds a strong handle to the example utility shader (`shaders/util.wesl`).
///
/// The custom material's shader imports `make_polka_dots` from this module. The
/// handle is never read; storing it in a resource is the whole point, since that
/// keeps a strong handle to the shader for the lifetime of the app.
#[expect(
    dead_code,
    reason = "used to keep a strong handle, shader is referenced by the material"
)]
#[derive(Resource)]
struct UtilityShader(Handle<Shader>);

impl Plugin for CustomMaterialPlugin {
    fn build(&self, app: &mut App) {
        let handle = app
            .world_mut()
            .resource_mut::<AssetServer>()
            .load::<Shader>("shaders/util.wesl");
        app.insert_resource(UtilityShader(handle));
    }
}

/// Builds the scene: one cube rendered with [`CustomMaterial`] (party mode off,
/// time zeroed) and a 3D camera looking at the origin.
fn setup(
    mut commands: Commands,
    mut meshes: ResMut<Assets<Mesh>>,
    mut materials: ResMut<Assets<CustomMaterial>>,
) {
    let cube_mesh = meshes.add(Cuboid::default());
    let cube_material = materials.add(CustomMaterial {
        time: Vec4::ZERO,
        party_mode: false,
    });
    commands.spawn((
        Mesh3d(cube_mesh),
        MeshMaterial3d(cube_material),
        Transform::from_xyz(0.0, 0.5, 0.0),
    ));

    let camera_transform = Transform::from_xyz(-2.0, 2.5, 5.0).looking_at(Vec3::ZERO, Vec3::Y);
    commands.spawn((Camera3d::default(), camera_transform));
}

/// Per-frame animation for every entity using [`CustomMaterial`]:
/// - writes the elapsed time (seconds) into `material.time.x` for the shader,
/// - flips `party_mode` when Space is pressed this frame,
/// - spins the entity around Y while its material is in party mode.
fn update(
    time: Res<Time>,
    mut query: Query<(&MeshMaterial3d<CustomMaterial>, &mut Transform)>,
    mut materials: ResMut<Assets<CustomMaterial>>,
    keys: Res<ButtonInput<KeyCode>>,
) {
    // Spin speed in radians per second (roughly the previous 0.005 rad/frame at 60 FPS).
    const PARTY_SPIN_SPEED: f32 = 0.3;

    // Loop-invariant inputs: read once per frame instead of once per entity.
    let elapsed = time.elapsed_secs();
    let toggle_party = keys.just_pressed(KeyCode::Space);
    let spin = PARTY_SPIN_SPEED * time.delta_secs();

    for (material_handle, mut transform) in &mut query {
        // The asset may be missing (e.g. removed); skip the entity rather than panic.
        let Some(material) = materials.get_mut(material_handle) else {
            continue;
        };
        material.time.x = elapsed;
        if toggle_party {
            material.party_mode = !material.party_mode;
        }

        if material.party_mode {
            // Scaled by frame delta so the rotation speed is frame-rate independent.
            transform.rotate_y(spin);
        }
    }
}

/// The material data passed to the shader.
///
/// Only `time` is part of the uniform (binding 0). `party_mode` has no binding
/// attribute; it reaches the GPU only through [`CustomMaterialKey`] (see
/// `#[bind_group_data]`), which selects the `PARTY_MODE` shader variant.
#[derive(Asset, TypePath, AsBindGroup, Clone)]
#[bind_group_data(CustomMaterialKey)]
struct CustomMaterial {
    // Elapsed seconds are written to `x` each frame. A `Vec4` (16 bytes) is used
    // instead of a lone `f32` because WebGL2 requires uniform buffers to be padded
    // to 16 bytes.
    #[uniform(0)]
    time: Vec4,
    party_mode: bool,
}

/// Pipeline-specialization key derived from a [`CustomMaterial`] via its
/// `From<&CustomMaterial>` impl; `Material::specialize` turns it into the
/// `PARTY_MODE` shader def.
/// NOTE(review): `repr(C)` is presumably required by how the key is stored for
/// specialization — confirm before removing it.
#[repr(C)]
#[derive(Eq, PartialEq, Hash, Copy, Clone)]
struct CustomMaterialKey {
    /// Mirrors `CustomMaterial::party_mode`.
    party_mode: bool,
}

impl From<&CustomMaterial> for CustomMaterialKey {
    fn from(material: &CustomMaterial) -> Self {
        Self {
            party_mode: material.party_mode,
        }
    }
}

impl Material for CustomMaterial {
    /// Uses the WESL fragment shader at [`FRAGMENT_SHADER_ASSET_PATH`].
    fn fragment_shader() -> ShaderRef {
        FRAGMENT_SHADER_ASSET_PATH.into()
    }

    /// Defines the boolean shader def `PARTY_MODE` from the material's key, so
    /// the shader's `@if(!PARTY_MODE)` / `@else` branches select the variant.
    fn specialize(
        _pipeline: &MaterialPipeline,
        descriptor: &mut RenderPipelineDescriptor,
        _layout: &MeshVertexBufferLayoutRef,
        key: MaterialPipelineKey<Self>,
    ) -> Result<(), SpecializedMeshPipelineError> {
        // `fragment` is optional: a pipeline built for this material may have no
        // fragment stage (NOTE(review): e.g. depth-only prepass/shadow pipelines).
        // Only the fragment shader reads PARTY_MODE, so skip instead of panicking.
        if let Some(fragment) = descriptor.fragment.as_mut() {
            fragment.shader_defs.push(ShaderDefVal::Bool(
                "PARTY_MODE".to_string(),
                key.bind_group_data.party_mode,
            ));
        }
        Ok(())
    }
}
import super::util::make_polka_dots;

// The subset of vertex-stage outputs this fragment shader reads: clip position and UV.
// NOTE(review): @location(2) presumably has to match where the mesh vertex shader
// writes its UV output — confirm against the vertex stage.
struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(2) uv: vec2<f32>,
}

// Mirrors the Rust `CustomMaterial` uniform (binding 0). `time.x` holds elapsed
// seconds; a vec4 (16 bytes) is used because WebGL2 requires uniform buffers to be
// padded to 16 bytes.
struct CustomMaterial {
    time: vec4<f32>,
}

@group(3) @binding(0) var<uniform> material: CustomMaterial;

// Fragment entry point: colors the surface with the polka-dot pattern from the
// utility module, driven by the mesh UV and the elapsed time in the uniform.
@fragment
fn fragment(
    mesh: VertexOutput,
) -> @location(0) vec4<f32> {
    let seconds = material.time.x;
    let color = make_polka_dots(mesh.uv, seconds);
    return color;
}
// Polka-dot pattern with 6 cells per unit of `pos` on each axis.
// Neighbouring cells alternate between two colors (checkerboard). Each cell draws a
// circular dot around its center on a black background; the output alpha is 1.
// When PARTY_MODE is defined, dot colors and sizes pulse with a diagonal wave
// driven by `time`.
fn make_polka_dots(pos: vec2<f32>, time: f32) -> vec4<f32> {
    let scaled_pos = pos * 6.0;
    // Position inside the current cell, in [0, 1) on each axis.
    let cell = fract(scaled_pos);
    var dist_from_center = distance(cell, vec2<f32>(0.5));

    // Integer cell coordinates (as floats), shared by both branches below.
    let grid = floor(scaled_pos);
    // 0.0 on even cells, 1.0 on odd cells. WGSL float `%` keeps the sign of the
    // dividend, so abs() keeps the value in {0, 1} for negative positions as well
    // (otherwise -1.0 would make mix() extrapolate outside the palette).
    let parity = abs((grid.x + grid.y) % 2.0);

    var dot_color = vec3<f32>(0.0);
    var is_dot = 0.0;

    @if(!PARTY_MODE) {
        let color1 = vec3<f32>(1.0, 0.4, 0.8);  // pink
        let color2 = vec3<f32>(0.6, 0.2, 1.0);  // purple
        dot_color = mix(color1, color2, parity);
        // Dot of radius 0.3 (cell units) around the cell center.
        is_dot = step(dist_from_center, 0.3);
    } @else {
        let wave_speed = 3.0;
        let wave_phase = time * wave_speed;

        // Cells with equal grid.x + grid.y share a phase, so the wave travels diagonally.
        let diagonal_pos = (grid.x + grid.y) * 0.5;
        let wave_value = sin(diagonal_pos + wave_phase);

        // Remap sin() from [-1, 1] to [0, 1] for use as a blend factor.
        let wave_normalized = (wave_value + 1.0) * 0.5;

        let color1 = vec3<f32>(1.0, 0.3, 0.7);
        let color2 = vec3<f32>(0.5, 0.1, 1.0);
        let intense_color1 = vec3<f32>(1.0, 0.1, 0.9);
        let intense_color2 = vec3<f32>(0.8, 0.0, 1.0);

        let animated_color1 = mix(color1, intense_color1, wave_normalized);
        let animated_color2 = mix(color2, intense_color2, wave_normalized);

        dot_color = mix(animated_color1, animated_color2, parity);

        // Scale the distance by up to +/-15% so dots visibly shrink and grow with the wave.
        let size_mod = 0.15 * wave_value;
        dist_from_center = dist_from_center * (1.0 - size_mod);
        // Animate whether something is a dot by position but also time
        is_dot = step(dist_from_center, 0.3 + wave_normalized * 0.2);
    }

    return vec4<f32>(dot_color * is_dot, 1.0);
}