Metal.jl 1.4: Improved random numbers


Christian Guinard

Metal.jl 1.4 adds higher-quality random number generators from the Metal Performance Shaders library. Some limitations apply; in those cases, Metal.jl falls back to the existing implementation.

Metal.rand and friends

Using functionality provided by the Metal Performance Shaders (MPS) library, Metal.jl now comes with much-improved GPU random number generators. Uniform distributions using Metal.rand (and its in-place variant Metal.rand!) are available for all Metal-supported integer types and Float32. However, due to Metal API limitations, arrays of 8-bit and 16-bit integers may fall back to the lower-quality GPUArrays.jl random number generator if their total size in bytes is not a multiple of 4. Normally distributed Float32 values can be generated with Metal.randn and Metal.randn!, while Float16 is not supported by the MPS library and will always fall back to the GPUArrays implementation.
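To make the fallback rule concrete, here is a small sketch of it in plain Julia that runs without a GPU. The helper uses_mps and the constant MPS_INT_TYPES are hypothetical names, not part of Metal.jl; the list of integer types is an assumption about what "Metal-supported" covers, and the function only encodes the rule as described in this post, so the actual dispatch inside Metal.jl may check additional conditions.

```julia
# Hypothetical sketch (not Metal.jl API): would an array of `len` elements of
# type `T` be filled by the MPS generator, per the rules described in this post?

# Assumed set of Metal-supported integer element types.
const MPS_INT_TYPES = (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64)

function uses_mps(::Type{T}, len::Integer; normal::Bool = false) where {T}
    # Normal samples (randn) come from MPS only for Float32; Float16 and
    # everything else falls back to the GPUArrays.jl generator.
    normal && return T === Float32
    # Uniform Float32 samples (rand) always come from MPS.
    T === Float32 && return true
    # Other non-integer types (e.g. Float16) fall back.
    T in MPS_INT_TYPES || return false
    # Integers: the total size in bytes must be a multiple of 4. This always
    # holds for 32- and 64-bit types, so it only matters for 8- and 16-bit ones.
    return (sizeof(T) * len) % 4 == 0
end

uses_mps(Int8, 3)                  # false: 3 bytes
uses_mps(Int8, 4)                  # true: 4 bytes
uses_mps(UInt16, 3)                # false: 6 bytes
uses_mps(Float16, 8)               # false: Float16 always falls back
uses_mps(Float32, 8; normal=true)  # true
```

In practice this means that, for small integer types, choosing array lengths that make the buffer a multiple of 4 bytes keeps you on the higher-quality generator.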

The easiest way to use these generators is through the Metal convenience functions Metal.rand[n][!], which you call just as you would the usual functions from the Random.jl standard library:

julia> a = Metal.rand(Float32, 2)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 0.95755994
 0.7110207

julia> Metal.randn!(a)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 1.7230463
 0.55636907

Alternatively, the standard Random.rand[n][!] functions from Random.jl can be used by passing them an MPS RNG, obtained either from MPS.default_rng() or by constructing a new one with MPS.RNG():

julia> using Random

julia> rng = MPS.RNG();

julia> Random.rand(rng, 2)
2-element MtlVector{Float32, Metal.PrivateStorage}:
 0.8941469
 0.67628527

Seeding is done by calling Metal.seed! for the global RNG, or Random.seed! when working with an explicit RNG object.
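As a sketch of how seeding works with an explicit RNG, the snippet below reseeds an MPS.RNG and checks that it reproduces the same Float32 values. It needs a Metal-capable Mac, and it assumes (this is not stated in the post) that calling Random.seed! again with the same value restarts the generator's stream; the seed value 42 is arbitrary.

```julia
using Metal, Random

# Create an explicit MPS-backed RNG (requires a Metal-capable Mac).
rng = MPS.RNG()

# Seed it, then fill a Float32 GPU array in place.
Random.seed!(rng, 42)
a = MtlArray{Float32}(undef, 8)
Random.rand!(rng, a)

# Reseed with the same value and draw into a second array.
Random.seed!(rng, 42)
b = similar(a)
Random.rand!(rng, b)

# Assumption: reseeding restarts the stream, so both draws should match.
# Copy to the CPU before comparing.
@assert Array(a) == Array(b)

# The global RNG behind Metal.rand is seeded with Metal.seed! instead.
Metal.seed!(42)
c = Metal.rand(Float32, 8)
```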

Other improvements since the last blog post

Future work

Although Metal.jl has reached v1, there is still work to be done to make it as fast and feature-complete as possible. In particular: