Metal.jl 1.6: Initial MPSGraph Support
Christian Guinard
Metal.jl 1.6 adds initial support for MPSGraph. The matrix multiplication functions are the first to be wrapped, which resolves some issues with the previous matrix multiplication method.
Initial MPSGraph support
PR #526 enabled the automatic generation of wrappers for all enums, structs, and Objective-C objects for the frameworks that Metal.jl relies upon. This made adding support for MPSGraph, Apple's MLIR GPU compiler interface, realistic.
To try out the new framework, the constructor and method wrappers necessary for matrix multiplication were added and linked to the LinearAlgebra interface, which works around the NaN issue that could show up on M1/M2 devices.
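As a rough sketch of what the LinearAlgebra link means in practice, ordinary matrix multiplication on `MtlArray`s needs no MPSGraph-specific code. The array sizes below are arbitrary, and the snippet assumes a machine with a functional Metal GPU. Which backend a given call is routed through depends on the element type and the Metal.jl version, so the snippet only checks the result against the CPU.

```julia
using Metal, LinearAlgebra

# Two Float32 matrices on the GPU (sizes chosen only for illustration)
A = MtlArray(rand(Float32, 64, 32))
B = MtlArray(rand(Float32, 32, 16))

# Allocating multiplication through the usual operator
C = A * B

# In-place multiplication into a preallocated GPU array
D = Metal.zeros(Float32, 64, 16)
mul!(D, A, B)

# Compare against the CPU result and check that no NaNs appeared
@assert Array(C) ≈ Array(A) * Array(B)
@assert Array(D) ≈ Array(C)
@assert !any(isnan, Array(C))
```

Copying back with `Array` keeps the comparison on the CPU, so the check does not depend on any particular GPU code path.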
Let's go through a simple example that does pairwise multiplication followed by pairwise addition using MPSGraph directly:
using Metal, Random
using ObjectiveC: Foundation.NSDictionary
using Metal: encode!
using .MPS: MPSCommandBuffer
using .MPSGraphs: MPSGraph, placeholderTensor, MPSGraphTensorData, MPSGraphTensor, multiplicationWithPrimaryTensor, additionWithPrimaryTensor
T = Float32;
# Use the element type T explicitly so the data matches the placeholder tensors
a = Metal.rand(T, 10);
b = Metal.rand(T, 10);
c = Metal.rand(T, 10);
# To compare with the MPSGraph equivalent
res = (a .* b) .+ c;
graph = MPSGraph() # Initialize the graph
# Create placeholder tensors to be used to compile our graph
placeA = placeholderTensor(graph, size(a), T)
placeB = placeholderTensor(graph, size(b), T)
placeC = placeholderTensor(graph, size(c), T)
# Link the placeholder tensors to the data via a Dict
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
placeA => MPSGraphTensorData(a),
placeB => MPSGraphTensorData(b),
placeC => MPSGraphTensorData(c)
)
# Add multiplication to the graph
pwisemul = MPSGraphs.multiplicationWithPrimaryTensor(graph, placeA, placeB)
# Add addition to the graph
pwiseadd = MPSGraphs.additionWithPrimaryTensor(graph, pwisemul, placeC)
# The output is written into the tensor data backing c, so c is overwritten with the result
resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
pwiseadd => feeds[placeC]
)
# Encode and run the graph
cmdbuf = MPS.MPSCommandBuffer(Metal.global_queue(device()))
MPS.encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict))
Metal.commit!(cmdbuf)
Metal.wait_completed(cmdbuf)
# The MPSGraph result is equal to the typical way of doing things.
@assert isapprox(res, c)
Clearly, for simple operations like the example above, this is a lot of extra boilerplate without much benefit. For more complex operations, however, MPSGraph will optimize the graph and its operations before running, reducing expensive kernel launches and removing unnecessary operations.
Another exciting aspect of this new framework wrapper is that it is now easier to add long-requested functionality. One can find MPSGraph functionality not yet in Metal.jl and write wrappers using the existing wrappers as a starting point. If anyone is interested in helping out, feel free to open a pull request or an issue on the Metal.jl repository, and we will do our best to help you get your code merged.
Minor Changes
Metal.jl 1.6 also includes several other useful updates:
As always, we encourage users to update to the latest version to benefit from these improvements and bug fixes. Check out the changelog for a full list of changes.