DIY Virtual Functions in Burst
Now that we’ve seen how function pointers work and perform in Burst, let’s use them to build a higher-level feature: virtual functions!
Say we are making a game where players cast spells on each other. In this game, we have two spells: fireball and life steal. The fireball spell does damage to its target. The life steal spell does damage to its target and restores health to the caster. Both cost mana, which the player has in addition to health.
Now say we’d like to implement this with virtual functions. We’ll assume that the pros and cons of virtual functions have been weighed and the pros have won out. Two issues stand in the way. The first is that the virtual keyword is unavailable to us because it can only be added to class methods and classes aren’t allowed by Burst. C# simply won’t allow us to make a struct method virtual. The second issue is that C# doesn’t allow a struct to inherit from another struct. Without inheritance, virtual functions make no sense.
Now that Burst supports function pointers, we can work around both of these issues and essentially build our own virtual functions. First up, let’s build our own form of inheritance:
[StructLayout(LayoutKind.Sequential)]
struct Spell
{
    public int ManaCost;
}

[StructLayout(LayoutKind.Sequential)]
struct Fireball
{
    public Spell Base;
    public int Damage;
}

[StructLayout(LayoutKind.Sequential)]
struct LifeSteal
{
    public Spell Base;
    public int Damage;
    public int Healing;
}
This isn’t inheritance as far as C# is concerned. The is operator will always evaluate to false, casting isn’t supported, keywords like virtual and override will error, we can’t use the base keyword, access specifiers like protected don’t apply, and so forth.
Instead, we have a “base class” in Spell that can be “derived” or “inherited” by Fireball and LifeSteal. Just like actual class inheritance, this means Fireball and LifeSteal have all the fields and methods that their base class has. This is because they both include a Spell as their first field. We can also perform upcasting. It just takes a little more work:
Fireball fireball = new Fireball();

// Upcast
ref Spell spell = ref fireball.Base;

// Access base class fields
spell.ManaCost = 1;
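The upcast relies on Spell being the first field of a sequential-layout struct, so a reference to the Base field and a reference to the whole Fireball start at the same address. Here is a minimal sketch of how that assumption could be checked in plain .NET, outside Unity and Burst. The names SpellLayout, FireballLayout, and LayoutCheck are hypothetical stand-ins rather than part of the code above, and Marshal.OffsetOf reports the marshaled layout, which matches the in-memory layout only because these structs contain nothing but int fields.

```csharp
using System;
using System.Runtime.InteropServices;

// Hypothetical stand-ins mirroring the Spell and Fireball structs.
[StructLayout(LayoutKind.Sequential)]
struct SpellLayout
{
    public int ManaCost;
}

[StructLayout(LayoutKind.Sequential)]
struct FireballLayout
{
    public SpellLayout Base;
    public int Damage;
}

static class LayoutCheck
{
    // Byte offset of the embedded "base class" within the "derived" struct.
    public static int BaseOffset() =>
        (int)Marshal.OffsetOf<FireballLayout>(nameof(FireballLayout.Base));

    // Byte offset of the first "derived" field, which follows the base.
    public static int DamageOffset() =>
        (int)Marshal.OffsetOf<FireballLayout>(nameof(FireballLayout.Damage));
}
```

If BaseOffset() ever returned something other than zero, treating a Spell reference as a Fireball reference would read the wrong memory, which is why the StructLayout attribute and the field order matter.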
Our game also has the Player type:
struct Player
{
    public int Health;
    public int Mana;
}
We’re also going to need some utility functions. The reason for this will become clearer later on. First up, here’s a function to convert one type of ref variable to another:
static class CastUtil
{
    public static unsafe ref TDest RefToRef<TSrc, TDest>(in TSrc src)
        where TSrc : unmanaged
        where TDest : unmanaged
    {
        fixed (TSrc* pSrc = &src)
        {
            TDest* dest = (TDest*)pSrc;
            return ref *dest;
        }
    }
}
We can use it like this:
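Quaternion q = Quaternion.identity;

// TDest can't be inferred from the argument, so both type arguments are explicit
ref float4 f = ref CastUtil.RefToRef<Quaternion, float4>(q);
f.w = 123;
print(q.w); // 123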
Next up, we have a non-generic version of FunctionPointer<T>. This is useful if we ever need an unmanaged version of the type.
public struct NonGenericFunctionPointer
{
    [NativeDisableUnsafePtrRestriction]
    private readonly IntPtr ptr;

    public NonGenericFunctionPointer(IntPtr ptr)
    {
        this.ptr = ptr;
    }

    public FunctionPointer<T> Generic<T>()
    {
        return new FunctionPointer<T>(ptr);
    }
}
Finally, we need a way to compile to NonGenericFunctionPointer instead of FunctionPointer<T>. This does it with a little bit of reflection:
static class BurstCompilerUtil<T>
    where T : class
{
    private static readonly MethodInfo compileMethodInfo;

    static BurstCompilerUtil()
    {
        foreach (var mi in typeof(BurstCompiler).GetMethods(
            BindingFlags.Default
            | BindingFlags.Static
            | BindingFlags.NonPublic))
        {
            if (mi.Name == "Compile")
            {
                compileMethodInfo = mi.MakeGenericMethod(typeof(T));
                break;
            }
        }
    }

    public static unsafe NonGenericFunctionPointer CompileFunctionPointer(T del)
    {
        var obj = compileMethodInfo.Invoke(null, new object[] { del, true });
        var ptr = Pointer.Unbox(obj);
        var intPtr = new IntPtr(ptr);
        return new NonGenericFunctionPointer(intPtr);
    }
}
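The reflection pattern used by BurstCompilerUtil (find a non-public static generic method by name, close it over a type argument, then invoke it) can be tried in isolation. The following is only a sketch in plain .NET: HiddenApi and its Echo method are hypothetical stand-ins for BurstCompiler and its private Compile method, whose real signature is not reproduced here.

```csharp
using System;
using System.Reflection;

// Hypothetical class standing in for a library with a hidden generic method.
static class HiddenApi
{
    private static T Echo<T>(T value) => value;
}

static class ReflectionSketch
{
    public static int CallHiddenEcho(int value)
    {
        // Enumerate the non-public static methods, as BurstCompilerUtil does.
        foreach (MethodInfo mi in typeof(HiddenApi).GetMethods(
            BindingFlags.Static | BindingFlags.NonPublic))
        {
            if (mi.Name == "Echo")
            {
                // Close the open generic method over int, then invoke it.
                MethodInfo closed = mi.MakeGenericMethod(typeof(int));
                return (int)closed.Invoke(null, new object[] { value });
            }
        }

        throw new MissingMethodException(nameof(HiddenApi), "Echo");
    }
}
```

Matching on the method name alone is fragile: a private method can be renamed or gain overloads in any release of the library, so code like this may need revisiting whenever the package is upgraded.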
Now we can proceed to write our spell-casting functions. First, let’s define a delegate type to serve as the T in FunctionPointer<T>:
delegate void CastFunction(
    ref Spell thiz,
    ref Player caster,
    ref Player target);
Now let’s fill in the base Spell type:
[BurstCompile]
[StructLayout(LayoutKind.Sequential)]
struct Spell
{
    public int ManaCost;
    public NonGenericFunctionPointer Cast;

    private static readonly NonGenericFunctionPointer SpellCast
        = BurstCompilerUtil<CastFunction>.CompileFunctionPointer(DoCast);

    public Spell(int manaCost)
    {
        ManaCost = manaCost;
        Cast = SpellCast;
    }

    [BurstCompile]
    private static void DoCast(
        ref Spell thiz,
        ref Player caster,
        ref Player target)
    {
        BaseCast(ref thiz, ref caster, ref target);
    }

    public static void BaseCast(
        ref Spell thiz,
        ref Player caster,
        ref Player target)
    {
        caster.Mana -= thiz.ManaCost;
    }
}
The BaseCast function here does the basic work of spell-casting: deduct the mana cost from the caster. It is deliberately left without [BurstCompile], and DoCast is a [BurstCompile] wrapper around it, to work around a limitation in Burst where [BurstCompile] functions can’t directly call each other. We compile DoCast into a NonGenericFunctionPointer with BurstCompilerUtil and store it statically to avoid re-compiling over and over. In the constructor, the Cast instance field is set to this static field.
Now let’s define Fireball to “derive” from Spell:
[BurstCompile]
[StructLayout(LayoutKind.Sequential)]
struct Fireball
{
    public Spell Base;
    public int Damage;

    private static readonly NonGenericFunctionPointer FireballCast
        = BurstCompilerUtil<CastFunction>.CompileFunctionPointer(DoCast);

    public Fireball(int manaCost, int damage)
    {
        Base = new Spell(manaCost) { Cast = FireballCast };
        Damage = damage;
    }

    [BurstCompile]
    private static void DoCast(
        ref Spell thiz,
        ref Player caster,
        ref Player target)
    {
        Spell.BaseCast(ref thiz, ref caster, ref target);
        ref var fireball = ref CastUtil.RefToRef<Spell, Fireball>(thiz);
        target.Health -= fireball.Damage;
    }
}
The structure here is very similar. We have a static field to avoid re-compilation. We have a constructor that sets the instance field to that static field. In this case, Fireball overwrites what Spell set with its own function pointer.
Then we have a DoCast that does the work of casting. Here we have the equivalent of a base.Cast call by starting with Spell.BaseCast. Then we use CastUtil to convert the Spell parameter that is basically this to the type we know it really is: Fireball. With that type in place, we can access its Damage field to reduce the Health of the Player that was targeted. The lack of a public BaseCast essentially makes this a sealed class.
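The downcast from the base reference back to the derived struct can be tried in plain .NET as well. This sketch swaps in Unsafe.As from System.Runtime.CompilerServices in place of the CastUtil.RefToRef helper, on the assumption that the runtime in use ships that API. The names SpellBase, FireballDerived, and DowncastSketch are hypothetical stand-ins for the types above.

```csharp
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

// Hypothetical stand-ins mirroring the Spell and Fireball structs.
[StructLayout(LayoutKind.Sequential)]
struct SpellBase
{
    public int ManaCost;
}

[StructLayout(LayoutKind.Sequential)]
struct FireballDerived
{
    public SpellBase Base;
    public int Damage;
}

static class DowncastSketch
{
    // Reinterprets a reference to the embedded base as a reference to the
    // whole derived struct. Only valid when the base really is the first
    // field of a FireballDerived; nothing checks this at run time.
    public static int DamageOf(ref SpellBase spell)
    {
        ref FireballDerived fireball =
            ref Unsafe.As<SpellBase, FireballDerived>(ref spell);
        return fireball.Damage;
    }

    public static int Demo()
    {
        var fireball = new FireballDerived
        {
            Base = new SpellBase { ManaCost = 1 },
            Damage = 10
        };

        // Upcast by taking a reference to Base, then downcast inside DamageOf.
        return DamageOf(ref fireball.Base);
    }
}
```

Unlike a real C# cast, nothing verifies the dynamic type here. Passing a standalone SpellBase to DamageOf would read past the end of the struct, so the function stored alongside the data is the only thing that guarantees the two match.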
Finally, let’s implement LifeSteal:
[BurstCompile]
[StructLayout(LayoutKind.Sequential)]
struct LifeSteal
{
    public Spell Base;
    public int Damage;
    public int Healing;

    private static readonly NonGenericFunctionPointer LifeStealCast
        = BurstCompilerUtil<CastFunction>.CompileFunctionPointer(DoCast);

    public LifeSteal(int manaCost, int damage, int healing)
    {
        Base = new Spell(manaCost) { Cast = LifeStealCast };
        Damage = damage;
        Healing = healing;
    }

    [BurstCompile]
    private static void DoCast(
        ref Spell thiz,
        ref Player caster,
        ref Player target)
    {
        Spell.BaseCast(ref thiz, ref caster, ref target);
        ref var lifeSteal = ref CastUtil.RefToRef<Spell, LifeSteal>(thiz);
        target.Health -= lifeSteal.Damage;
        caster.Health += lifeSteal.Healing;
    }
}
This is nearly identical to Fireball, except that DoCast has different game logic because it also heals the caster.
Now that we have these “virtual” functions, let’s use them! Here’s a Burst-compiled job that casts a spell on many targets:
[BurstCompile]
unsafe struct CastJob : IJob
{
    [NativeDisableUnsafePtrRestriction] public Spell* Spell;
    public NativeArray<Player> Caster;
    public NativeArray<Player> Targets;

    public void Execute()
    {
        ref var spellRef = ref *Spell;
        var spell = spellRef.Cast.Generic<CastFunction>();
        var caster = Caster[0];
        for (int i = 0; i < Targets.Length; ++i)
        {
            var target = Targets[i];
            spell.Invoke(ref spellRef, ref caster, ref target);
            Targets[i] = target;
        }
        Caster[0] = caster;
    }
}
We have a Spell* pointer for the spell to cast, the Caster in a single-element NativeArray<Player>, and the targets to cast the spell on. Execute converts the Spell* to a ref Spell then gets the NonGenericFunctionPointer field Cast and uses Generic to recover the strongly-typed FunctionPointer<CastFunction>. It then loops over the targets invoking that function pointer on them.
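To see the dispatch logic without Unity or Burst, here is a rough simulation in plain .NET. It is not the technique above: ordinary delegates in a static table, indexed by a hypothetical CastId field, stand in for the Burst-compiled function pointer stored in Spell, and Unsafe.As stands in for CastUtil.RefToRef. All type and member names ending in Sim, along with DispatchSketch, are made up for this sketch.

```csharp
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

struct PlayerSim { public int Health; public int Mana; }

[StructLayout(LayoutKind.Sequential)]
struct SpellSim { public int ManaCost; public int CastId; }

[StructLayout(LayoutKind.Sequential)]
struct FireballSim { public SpellSim Base; public int Damage; }

[StructLayout(LayoutKind.Sequential)]
struct LifeStealSim { public SpellSim Base; public int Damage; public int Healing; }

delegate void CastSim(ref SpellSim thiz, ref PlayerSim caster, ref PlayerSim target);

static class DispatchSketch
{
    // Stand-in for the compiled function pointers: CastId indexes this table.
    static readonly CastSim[] Table = { CastFireball, CastLifeSteal };

    static void CastFireball(ref SpellSim thiz, ref PlayerSim caster, ref PlayerSim target)
    {
        caster.Mana -= thiz.ManaCost; // "base class" behavior
        ref FireballSim fb = ref Unsafe.As<SpellSim, FireballSim>(ref thiz);
        target.Health -= fb.Damage;
    }

    static void CastLifeSteal(ref SpellSim thiz, ref PlayerSim caster, ref PlayerSim target)
    {
        caster.Mana -= thiz.ManaCost; // "base class" behavior
        ref LifeStealSim ls = ref Unsafe.As<SpellSim, LifeStealSim>(ref thiz);
        target.Health -= ls.Damage;
        caster.Health += ls.Healing;
    }

    // Mirrors the loop in CastJob.Execute: resolve the function once,
    // then invoke it for every target.
    public static void CastOnAll(ref SpellSim spell, ref PlayerSim caster, PlayerSim[] targets)
    {
        CastSim cast = Table[spell.CastId];
        for (int i = 0; i < targets.Length; ++i)
        {
            cast(ref spell, ref caster, ref targets[i]);
        }
    }

    public static (PlayerSim Caster, PlayerSim Target) Demo()
    {
        var caster = new PlayerSim { Health = 100, Mana = 10 };
        var targets = new[] { new PlayerSim { Health = 100, Mana = 10 } };

        var fb = new FireballSim
        {
            Base = new SpellSim { ManaCost = 1, CastId = 0 },
            Damage = 10
        };
        CastOnAll(ref fb.Base, ref caster, targets);

        var ls = new LifeStealSim
        {
            Base = new SpellSim { ManaCost = 1, CastId = 1 },
            Damage = 5,
            Healing = 5
        };
        CastOnAll(ref ls.Base, ref caster, targets);

        return (caster, targets[0]);
    }
}
```

With one fireball (mana cost 1, damage 10) followed by one life steal (mana cost 1, damage 5, healing 5), Demo leaves the caster at 105 health and 8 mana and the target at 85 health and 10 mana. CastOnAll only ever sees a SpellSim reference, yet the derived behavior runs, which is the essence of the "virtual" call.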
Digression: The Spell* is the reason we need NonGenericFunctionPointer and BurstCompilerUtil. We can’t take a pointer to a FunctionPointer<T> or any struct that contains one, including Spell. This is because C# considers the type to be “managed” due to being generic, regardless of its actual contents: a single IntPtr. By using NonGenericFunctionPointer instead, returned by BurstCompilerUtil, we have a non-generic struct that passes the C# check. We could convert FunctionPointer<T> to NonGenericFunctionPointer instead of using BurstCompilerUtil, but the IntPtr is private and the type is “managed” so there’s no way to get at it other than reflection.
Let’s finish things off with a script to test out this functionality:
class TestScript : MonoBehaviour
{
    [BurstCompile]
    static class Dummy
    {
        [BurstCompile]
        public static void DoCast(
            ref Spell thiz,
            ref Player caster,
            ref Player target)
        {
        }
    }

    public FunctionPointer<CastFunction> dummy;

    unsafe void Start()
    {
        dummy = BurstCompiler.CompileFunctionPointer<CastFunction>(Dummy.DoCast);

        var caster = new NativeArray<Player>(1, Allocator.TempJob);
        caster[0] = new Player { Health = 100, Mana = 10 };
        var targets = new NativeArray<Player>(1, Allocator.TempJob);
        targets[0] = new Player { Health = 100, Mana = 10 };

        void printState(string label)
        {
            print(
                $"{label}:\n" +
                $"Caster: Health={caster[0].Health}, Mana={caster[0].Mana}\n" +
                $"Target: Health={targets[0].Health}, Mana={targets[0].Mana}");
        }

        printState("Start");

        var fb = new Fireball(manaCost:1, damage:10);
        var job = new CastJob
        {
            Spell = &fb.Base,
            Caster = caster,
            Targets = targets
        };
        job.Run();
        printState("After fireball");

        var ls = new LifeSteal(manaCost:1, damage:5, healing:5);
        job.Spell = &ls.Base;
        job.Run();
        printState("After life steal");

        caster.Dispose();
        targets.Dispose();
    }
}
Note that the dummy parts are only there to prevent BurstCompiler being stripped from builds because we’re only using it via reflection.
All we’re doing here is making one player cast a fireball spell then a life steal spell at a target player and printing out their stats each time. Here’s what we see:
Start:
Caster: Health=100, Mana=10
Target: Health=100, Mana=10

After fireball:
Caster: Health=100, Mana=9
Target: Health=90, Mana=10

After life steal:
Caster: Health=105, Mana=8
Target: Health=85, Mana=10
It works! But let’s confirm by digging a little deeper into the Burst Inspector to make sure it’s doing what we expect: (annotations by Jackson)
; Spell.DoCast
    mov eax, dword ptr [rdi]
    sub dword ptr [rsi + 4], eax ; Deduct mana
    ret

; Fireball.DoCast
    mov eax, dword ptr [rdi]
    sub dword ptr [rsi + 4], eax ; Deduct mana
    mov eax, dword ptr [rdi + 16]
    sub dword ptr [rdx], eax ; Apply damage
    ret

; LifeSteal.DoCast
    mov eax, dword ptr [rdi]
    sub dword ptr [rsi + 4], eax ; Deduct mana
    mov eax, dword ptr [rdi + 16]
    sub dword ptr [rdx], eax ; Apply damage
    mov eax, dword ptr [rdi + 20]
    add dword ptr [rsi], eax ; Restore health
    ret

; CastJob.Execute
    push rbp
    push r15
    push r14
    push r13
    push r12
    push rbx
    sub rsp, 24
    mov r14, qword ptr [rdi]
    mov rax, qword ptr [rdi + 8]
    mov r13, qword ptr [r14 + 8]
    mov ecx, dword ptr [rax]
    mov edx, dword ptr [rax + 4]
    mov dword ptr [rsp + 8], ecx
    mov dword ptr [rsp + 12], edx
    cmp dword ptr [rdi + 72], 0
    jle .LBB0_4
    mov rbx, rdi
    xor ebp, ebp
    lea r15, [rsp + 8]
    lea r12, [rsp + 16]
.LBB0_2:
    mov rax, qword ptr [rbx + 64]
    mov rax, qword ptr [rax + 8*rbp]
    mov qword ptr [rsp + 16], rax
    mov rdi, r14
    mov rsi, r15
    mov rdx, r12
    call r13 ; Call "virtual" function
    mov rax, qword ptr [rsp + 16]
    mov rcx, qword ptr [rbx + 64]
    mov qword ptr [rcx + 8*rbp], rax
    inc rbp
    movsxd rax, dword ptr [rbx + 72]
    cmp rbp, rax
    jl .LBB0_2
    mov ecx, dword ptr [rsp + 8]
    mov edx, dword ptr [rsp + 12]
    mov rax, qword ptr [rbx + 8]
.LBB0_4:
    mov dword ptr [rax], ecx
    mov dword ptr [rax + 4], edx
    add rsp, 24
    pop rbx
    pop r12
    pop r13
    pop r14
    pop r15
    pop rbp
    ret
The job is a bit long, but we can see it calling the “virtual” function. The others are really short and look pretty close to their C# counterparts, except that the Spell.BaseCast call has been inlined.
There’s one important thing to remember with this technique: any base class pointers like Spell* passed to jobs need to remain valid until the job completes. For example, if the pointer is to a local variable, as in the example, but the job completes after the function returns, then the job will be accessing stack memory that may have been overwritten by later function calls. The result may well be a crash or data corruption. Be careful!
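One way to avoid the dangling-pointer problem is to keep the spell in memory whose lifetime is controlled explicitly rather than on the stack. The following is only a sketch of that idea in plain .NET, using Marshal.AllocHGlobal. In a Unity project, one of Unity's own native allocators would likely be a better fit, and that is not shown here. SpellData, FireballData, HeapSpell, and LifetimeSketch are hypothetical names.

```csharp
using System;
using System.Runtime.InteropServices;

// Hypothetical stand-ins mirroring the Spell and Fireball structs.
[StructLayout(LayoutKind.Sequential)]
struct SpellData
{
    public int ManaCost;
}

[StructLayout(LayoutKind.Sequential)]
struct FireballData
{
    public SpellData Base;
    public int Damage;
}

// Owns a copy of the spell in unmanaged memory. The address stays valid
// until Dispose is called, no matter which function returns in the meantime.
sealed class HeapSpell : IDisposable
{
    public IntPtr Address { get; private set; }

    public HeapSpell(FireballData fireball)
    {
        Address = Marshal.AllocHGlobal(Marshal.SizeOf<FireballData>());
        Marshal.StructureToPtr(fireball, Address, false);
    }

    // Copies the spell back out of unmanaged memory.
    public FireballData Read() => Marshal.PtrToStructure<FireballData>(Address);

    public void Dispose()
    {
        if (Address != IntPtr.Zero)
        {
            Marshal.FreeHGlobal(Address);
            Address = IntPtr.Zero;
        }
    }
}

static class LifetimeSketch
{
    public static int Demo()
    {
        var fireball = new FireballData
        {
            Base = new SpellData { ManaCost = 1 },
            Damage = 10
        };

        using (var spell = new HeapSpell(fireball))
        {
            // spell.Address could be handed to long-running work here;
            // it is freed only when the using block ends.
            return spell.Read().Damage;
        }
    }
}
```

The trade-off is that the owner must now free the memory, and must do so only after every job that uses the pointer has completed.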
#1 by Neil Henning on March 23rd, 2020 ·
So it’s super cool that you are delving into the power user features of Burst – we’re stoked!
I wanted to point y’all to a section in the documentation that I wrote on how best to use function pointers within Burst – https://docs.unity3d.com/Packages/com.unity.burst@1.3/manual/index.html#performance-considerations
The TL;DR of the above is that you should really avoid using function pointers unless you really have to – simple jobs are best.
If you do have to use them, then it is important that you batch up the function pointers so that you are processing N items within each function pointer, rather than calling a function pointer for each item. This will maximise the compiler’s ability to optimize your code!
#2 by jackson on March 23rd, 2020 ·
Thanks for writing up that section of the docs and for linking it here. I definitely agree with your advice. Looking forward to Burst 1.3!
#3 by Dan Miller on April 27th, 2021 ·
This is super interesting stuff! I wonder if it can be applied to a subset of lambdas generated by Linq’s Expression library
I have a problem which requires interpretation of mathematical and boolean expressions at runtime, so I’m using Expressions to compile those to lambdas. Problem is that those are not burst compatible at all, even though the only thing happening inside them is addition, sqrts, boolean ORs, etc. It would be very handy to turn those into runtime-compiled burst-compatible function pointers