(Compute) Shader Basics With Unity - N-Body Simulation

(EDIT: I unfortunately lost the images related to this article. Some of the meaning of the text may thus be lost. Apologies.)

In this small tutorial, we'll take a look at the basics of shader programming. We're using the Unity engine as our framework, though the tutorial is meant to be a very generically applicable look at the basics of shader programming. The reason we'll use Unity is that graphics programming involves a lot of (fairly boring) skeleton code that Unity will handle for us. On the shader side, we'll mostly deal in HLSL (High-Level Shading Language) and DirectCompute, which are part of Microsoft's DirectX suite.

You should not need to be familiar with Unity already to follow this tutorial, but you should be familiar with the basics of programming. If you're looking for resources for learning the basic concepts of shader and graphics programming, this is for you!

The Goal

Our goal is to create a naive solution to the N-body problem. One common example of the problem would be predicting the positions of celestial objects. This is a problem to which no analytic solution exists; that is, there's no set of computationally feasible equations that would give an exact answer. Hence, the way to solve it is simulation.

The naive solution to the N-body problem is to take each object and calculate the total force of gravity exerted upon it by all the other objects. This is a classic example of a problem with O(n^2) complexity in big O notation.

The reason we picked the N-body problem is that, firstly, it's a very good problem to parallelize. Secondly, it's a relatively simple problem to solve naively, so it won't distract us too much from understanding how our shader code works. Thirdly, it tends to end up looking kind of nice!

The fact that the problem is easy to parallelize makes it a good problem for the GPU to solve, which leads us to:

GPU, The Supreme Multitasker

Modern GPUs were developed to process as many vertices and pixels as possible in as little time as possible, in order to render ever more complex and realistic 3D scenes. This turned out to be a task where extreme parallelization is very useful: each vertex and pixel is processed by code similar to that processing every other vertex and pixel. Since (static, non-animated) vertices and pixels do not affect each other, it's easy to run the code processing them in parallel.

Today, GPUs have up to thousands of cores in them, each able to execute code in parallel with all the other cores. GPUs and similar multiprocessing architectures are no longer used for just rendering, but also for complex data analysis, propagating neural networks, and lots of other tasks that can be parallelized well.

I'll stop here to keep this general description of GPUs short - a full introduction to GPUs and parallel computation would need an article (or a book.. or a book series?) of its own! If you want to learn more about the subject, I'll include some up-to-date resources at the end.

Starting Out With Our Problem

So what we want to do is to simulate a large number of objects, all of them under the gravitational influence of each other. The simulation should run fully on the GPU, with the CPU acting as a mere coordinator.

We need to:

  • Represent positions, velocities and masses of stellar objects on the GPU;
  • Update the positions according to gravitational influences;
  • Render a representation of these objects on the screen.

We'll calculate the acceleration on each object and update the positions with a compute shader, a special type of shader meant for arbitrary computation. Then we'll tell the GPU to render N quads representing our stars, and use a vertex shader to set the positions of these quads on the screen. Finally, we'll use a fragment shader (also known as a pixel shader) to color those quads.


To calculate the gravitational forces and update the positions, our whole compute shader is as follows:

// Each #kernel tells which function to compile; you can have many kernels
#pragma kernel CSMain

uniform RWStructuredBuffer<float3> position : register(u1);
uniform RWStructuredBuffer<float3> velocity : register(u2);

[numthreads(256,1,1)]
void CSMain (uint3 id : SV_DispatchThreadID)
{
	// Total force
	float3 t_force = float3(0.0f, 0.0f, 0.0f);

	// Iterate every "star".
	for (uint i = 0; i < 256 * 256; i++)
	{
		float dist = distance(position[i], position[id.x]);
		// We'll pretend we're in a 2D universe where g = inverse distance rather than g = inverse distance^2.
		// Otherwise the forces end up in such a range that we'll start having issues with floating point precision.
		// Plus, with the inverse square law, we'd have to be more careful in tuning the time and size scales etc. - otherwise we end up with something extremely unstable.
		float g = 1.0 / dist * 0.0001;

		// Direction vector. We add 1e-10 to avoid NaN when length == 0.
		float3 dir = position[i] - position[id.x];
		dir = dir / (length(dir) + 1e-10);

		float3 force = g * dir * 0.001;

		// Don't apply the force to the total force if the object would be affecting itself.
		if (i != id.x)
			t_force += force;
	}

	// Simplified Verlet/leapfrog integration
	position[id.x] += velocity[id.x] + t_force / 2.0f;
	velocity[id.x] += t_force;
}

RWStructuredBuffer is a type of data buffer that allows both reading and writing (hence the RW). This type of buffer can also contain custom-defined structs; for us, though, a buffer of float3s for each of position and velocity is enough. register(u1) and register(u2) bind the buffers to (global) GPU registers. These registers keep their references between different shaders, so we can write to a buffer in one shader and read from it in another without needing to involve the CPU in any data exchange.

With [numthreads(256,1,1)] we define how many threads are executed per thread group. When you're working with rendering, this is typically decided by the drivers. But with compute shaders, a little help from the programmer is needed, as the driver can make fewer assumptions about the sort of calculations and amounts of output we want. The three numbers are merely for the programmer's convenience, making it easier to define thread counts for two- and three-dimensional data.

Later, in the C# code, we set the number of thread groups over which we want our compute shader executed. The shader is thus run threads * thread groups times.

Back to the compute shader! The single argument, id, given to the kernel function identifies the current thread. Since we are working over one-dimensional data, id.x for us is in the range 0 to threads * thread groups - 1. If we were working with two-dimensional data, such as images, we might have [numthreads(16,16,1)] and use id.xy to identify the resource we're supposed to compute on.

[Image: thread groups, threads and dispatch - from researchgate.net]

Euler's Problem

At the end of the kernel, we update the velocity and position. Perhaps the simplest way to do this would be with the standard Euler method:

velocity = velocity + acceleration * timestep
position = position + velocity * timestep

The problem with this method is its instability. Unless timestep is very small (we don't use a timestep at the moment, just to simplify things a little), energy is not conserved, and orbits and similar systems without damping tend to have their total energy increase. Here's the worst-case scenario of how our simulation might look with the Euler method:

Energy continues to increase in the system until the whole thing is flung off the screen.

To fix this, we can use leapfrog integration instead. The basic idea is to interleave the updates to acceleration, velocity and position, leading to a closer approximation of continuous change rather than discrete jumps in these properties.

Normally you would want to keep track of acceleration separately and have forces apply to the acceleration, but we'll be a little bit lazy and use a simplified version that doesn't need to keep track of acceleration:

position = position + velocity * timestep + acceleration/2.0f * timestep
velocity = velocity + acceleration * timestep

The reason for the /2.0f is that we increase position before increasing velocity; hence, the position update works with the velocity from the last frame. We make up for this by factoring half of the new acceleration into the position ahead of time. Thus, we are "leaping".

A little on optimization..

GPU programming has some specific optimization concerns. The extreme parallelism is partially achieved by executing the same code at the same time on multiple different sets of data. Thus branching is something to avoid where possible: threads in a group execute in lockstep, so when threads disagree on which side of a branch to take, the GPU has to execute both paths and mask out the results each thread doesn't need.

For example, if we change the code from this:

	for (uint i = 0; i < 256 * 256; i++)
	{
		[stuff here]
		if (i != id.x)
			t_force += force;
	}

To this:

	for (uint i = 0; i < 256 * 256; i++)
	{
		if (i != id.x)
		{
			[stuff here]
			t_force += force;
		}
	}
We incur a massive performance penalty.

It's also worth trying to limit the number of operations being done. The GPU is very effective at running a small amount of branchless code over a little bit of data at a time. Often it's more effective to split an algorithm into multiple stages and run each separately than to make one large compute shader that tries to compute everything at once.


This is the whole of our C# code for setting up and dispatching computation:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using System;

public class NBodyCompute : MonoBehaviour {

    public ComputeShader shader;

    public static ComputeBuffer pos_buf;
    public static ComputeBuffer vel_buf;

    void Start () {
        pos_buf = new ComputeBuffer(256 * 256, 3 * sizeof(float), ComputeBufferType.Default);
        vel_buf = new ComputeBuffer(256 * 256, 3 * sizeof(float), ComputeBufferType.Default);

        // These global buffers apply to every shader with these buffers defined.
        Shader.SetGlobalBuffer(Shader.PropertyToID("position"), pos_buf);
        Shader.SetGlobalBuffer(Shader.PropertyToID("velocity"), vel_buf);

        // Three floats (x, y, z) per star.
        float[] pos_data = new float[256 * 256 * 3];
        float[] vel_data = new float[256 * 256 * 3];

        for (int i = 0; i < 256 * 256; i++) {
            float rand_x = (0.5f - (i % 256) / 256.0f) * 16.0f;
            float rand_y = (0.5f - (i / 256) / 256.0f) * 16.0f;

            pos_data[i * 3 + 0] = rand_x;
            pos_data[i * 3 + 1] = rand_y;
            // We are 2D, so set the z axis to zero.
            pos_data[i * 3 + 2] = 0.0f;

            // If position was a vector from origin (0,0), this would turn it 90 degrees - e.g. create circular motion.
            vel_data[i * 3 + 0] = rand_y * 0.01f;
            vel_data[i * 3 + 1] = -rand_x * 0.01f;
            vel_data[i * 3 + 2] = 0.0f;
        }

        pos_buf.SetData(pos_data);
        vel_buf.SetData(vel_data);
    }

    // Update is called once per frame
    void Update () {
        shader.Dispatch(shader.FindKernel("CSMain"), 256, 1, 1);
    }

    void OnDestroy() {
        pos_buf.Release();
        vel_buf.Release();
    }
}

The public variable shader is there so we can specify the compute shader we want to use via Unity's editor. pos_buf and vel_buf contain the buffers for the positions and velocities of our objects.

With the SetGlobalBuffer calls, we set position and velocity as global buffers.

In the loop in Start, we set some initial data into the buffers. In this case, we create a roughly uniform grid of star bodies and give each a "circular" velocity vector. You can try various other initial parameters to see how they influence the simulation.

Here's how it looks with uniform grid positioning and circular starting motion:


This is all the C# code to render our stars:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class StarCreatorRenderer : MonoBehaviour {
    public Material material;

    Matrix4x4[][] transformList;

    Mesh mesh;

    const int instance_max = 1023;
    const int wanted_instances = 256 * 256;

    const float star_size = 1.0f;

    // Ceiling division, so the final partial set of stars is included.
    const int set_count = (wanted_instances + instance_max - 1) / instance_max;

    // Use this for initialization
    void Start () {
        transformList = new Matrix4x4[set_count][];

        MeshFilter mf = GetComponent<MeshFilter>();

        mesh = new Mesh();
        mf.mesh = mesh;

        // Create a basic quad.
        Vector3[] vertices = new Vector3[4];

        vertices[0] = new Vector3(-star_size, -star_size, 0);
        vertices[1] = new Vector3(star_size, -star_size, 0);
        vertices[2] = new Vector3(-star_size, star_size, 0);
        vertices[3] = new Vector3(star_size, star_size, 0);

        mesh.vertices = vertices;

        int[] tri = new int[6];

        tri[0] = 0;
        tri[1] = 2;
        tri[2] = 1;

        tri[3] = 2;
        tri[4] = 3;
        tri[5] = 1;

        mesh.triangles = tri;

        // Only 1023 objects can be rendered as instanced.
        // So we split the 256*256 objects into sets of (at most) 1023.
        for (int set = 0; set < set_count; set++) {
            int instances = instance_max;
            if (set == set_count - 1)
                instances = wanted_instances - set * instance_max;

            transformList[set] = new Matrix4x4[instances];

            for (int i = 0; i < instances; i++) {
                Matrix4x4 matrix = new Matrix4x4();
                matrix.SetTRS(Vector3.zero, Quaternion.Euler(Vector3.zero), Vector3.one);
                transformList[set][i] = matrix;
            }
        }
    }

    // Update is called once per frame
    void Update () {
        for (int set = 0; set < set_count; set++) {
            int instances = transformList[set].Length;

            MaterialPropertyBlock mpb = new MaterialPropertyBlock();
            mpb.SetInt("offset", set * instance_max);
            if (set < set_count / 2)
                mpb.SetColor("color", new Color(0.45f, 0.5f, 0.75f, 0.5f));
            else
                mpb.SetColor("color", new Color(0.9f, 0.4f, 0.5f, 0.5f));
            Graphics.DrawMeshInstanced(mesh, 0, material, transformList[set], instances, mpb);
        }
    }
}

To render the stars, what we first need is a quad! The graphics pipeline prefers to deal with triangles, so we'll compose our quad of a pair of triangles.

[Image: quad vertices - from ogldev.atspace.co.uk]

We could also provide normals and UV (texture) coordinates for our mesh of two triangles, but those are not necessary for our needs.

Once our mesh has been created (we set it to some associated Unity components, so we could test the quad by simply turning on the "Mesh Renderer" component on the Unity object holding the above code), we need to create a bunch of transform matrices. The reason we do this is that we render our stars using GPU instancing. This is a special method of rendering where the GPU is told to repeat the rendering of a specific mesh several times over. It's very efficient: rendering the 256*256*2 triangles takes no more than a millisecond or so.

Either way, GPU instancing is a little bit special. You can't set per-instance variables for your shader like you normally would; instead, you define parameters per batch of instances. That's what the MaterialPropertyBlock is for.

We want to render 256*256 stars, but since we can only instance 1023 at a time, there's some extra fluff code involved in figuring out how many to render at a time.

Shading The Stars

Here's our shader code for rendering the stars:

Shader "Custom/Star"
{
	SubShader
	{
		Tags { "Queue" = "Transparent" "RenderType" = "Transparent" }
		LOD 100

		Cull Off
		ZWrite Off
		Blend One One

		Pass
		{
			CGPROGRAM
			#pragma target 5.0

			#pragma vertex vert
			#pragma fragment frag

			#pragma multi_compile_instancing

			#include "UnityCG.cginc"

			uniform StructuredBuffer<float3> position : register(t1);

			uniform int offset;
			uniform float4 color = float4(0.45f, 0.5f, 0.75f, 0.5f);

			struct appdata
			{
				float4 vertex : POSITION;
				UNITY_VERTEX_INPUT_INSTANCE_ID
			};

			struct v2f {
				float4 pos : SV_POSITION;
				float4 vertex : VERTEX;
			};

			// Vertex shader
			v2f vert(appdata v)
			{
				v2f o;

				UNITY_SETUP_INSTANCE_ID(v);

				o.pos = UnityObjectToClipPos(v.vertex);
				o.pos += float4(position[unity_InstanceID + offset], 0.0f);
				o.vertex = v.vertex;

				return o;
			}

			// Pixel shader
			fixed4 frag(v2f i) : SV_Target
			{
				float dist = distance(i.vertex, float4(0.0, 0.0, 0.0, 0.0));
				float multiplier = 0.1 / pow(dist, 100.0);

				if (multiplier > 1)
					multiplier = 1;

				return color * multiplier;
			}
			ENDCG
		}
	}
}
Here there's a little bit more Unity-specific code, but for the most part it's standard HLSL. We turn culling off, turn writing to the depth buffer off, set additive blending, and proceed to the vertex and fragment shaders.

These two types of shaders are what are most commonly used in rendering pipelines (there are a few additional shader types, but they aren't nearly as ubiquitous).

The vertex shader is invoked once per vertex. Transformations to the vertices, and thus the whole mesh, are applied here.

Pixel shaders (also known as fragment shaders) are typically invoked for each pixel being rendered. This may be the pixels on the screen, if rendering directly to the screen buffer, or the pixels of a target texture.

In the rendering pipeline, data is passed from the vertex shader to the pixel shader. Both can access uniform variables (variables that can be set from the CPU side) and any buffers (such as the StructuredBuffer on register(t1), which points to the same buffer as in the compute shader thanks to the SetGlobalBuffer call).

[Image: the programmable graphics pipeline - from researchgate.net]

UNITY_SETUP_INSTANCE_ID is needed to get the ID of the current instance. We use the uniform variable offset to accommodate the set system we created earlier, due to the fact that we can only instance 1023 meshes at a time.

When passing data from the vertex shader to the pixel shader, the data is, by default, interpolated: if a vertex at one end of your polygon has the value 1 for something and the vertex at the other end has the value 0, then in the middle of your polygon the fragment shader receives the value 0.5.

Using this to our advantage, we pass the vertex positions to the fragment shader. Then we use a steep falloff based on the distance to the middle of the quad mesh to create an effect commonly known as metaballs:

[Image: a single metaball, and metaballs blending - from nullcandy.com]

Metaballs are a neat, cheap way to add some quick prettiness to pretty much anything.

Now we are able to render our stars under our make-do N-body simulation! We can screw around with the settings a bit: change up the colors, star sizes, initial parameters, positions, velocities.. All kinds of cool stuff can be tried out, like, for example, disabling the clearing of the screen buffer:

Music sample from Carbon Based Lifeforms

That's it, then! Unity is free, and if you want to dive into the code, you can find it on GitLab. Just open the NBodyScene scene file in Unity and press Play.

Feel free to drop a comment, I'm open to both criticism and praise. I'd also love to hear if there's something that was left a bit ambiguous or something you'd like me to explain more deeply in another article.

Here are some additional resources I quickly collected from around the web:

Unity fragment & vertex shader examples: https://docs.unity3d.com/Manual/SL-VertexFragmentShaderExamples.html

Introduction to GPU computation by Nvidia (this is a great read): http://www.int.washington.edu/PROGRAMS/12-2c/week3/clark_01.pdf

Introduction to shader programming starting from the very basics: https://thebookofshaders.com/00/

Introduction to GPU programming that seemed pretty decent, in Python: https://nyu-cds.github.io/python-gpu/01-introduction/

Some more about parallel computation (though more from the viewpoint of general programming on the CPU rather than the GPU): https://computing.llnl.gov/tutorials/parallel_comp/

Jalmari Ikävalko

Read more posts by this author.