FluxML/Zygote.jl: 21st century AD
Installation (in the Julia REPL, where ] switches to package mode):
] add Zygote
Zygote provides source-to-source automatic differentiation (AD) in Julia, and is the next-gen AD system for the Flux differentiable programming framework. For more details and benchmarks of Zygote's technique, see our paper. You may want to check out Flux for more interesting examples of Zygote usage; the documentation here focuses on internals and advanced AD usage.
Zygote supports Julia 1.10 onwards.
julia> using Zygote

julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5.0)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
top:
  ret i64 5
}
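The example above uses the f' shorthand. The sketch below shows the more general gradient function it is built on, under the assumption (consistent with the example) that f' behaves like x -> gradient(f, x)[1] for a single-argument function. gradient returns a tuple with one derivative per argument. Floating-point inputs are used here so every result is a Float64.

```julia
using Zygote

f(x) = 5x + 3

# gradient returns a tuple: one entry per argument of f
gradient(f, 10.0)                        # → (5.0,)

# f' is shorthand for the derivative of a single-argument function
f'(10.0)                                 # → 5.0

# multi-argument functions get one derivative per argument:
# ∂/∂a (a*b + b) = b, ∂/∂b (a*b + b) = a + 1
gradient((a, b) -> a * b + b, 2.0, 3.0)  # → (3.0, 3.0)
```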
"Source-to-source" means that Zygote hooks into Julia's compiler, and generates the backwards pass for you, as if you had written it by hand.
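To make "as if you had written it by hand" concrete, here is a minimal sketch comparing a hand-written backward pass with the one Zygote generates. The names h and h_by_hand are hypothetical, chosen for illustration; pullback is Zygote's exported function that returns the forward value together with a closure mapping an output sensitivity to a tuple of input sensitivities.

```julia
using Zygote

h(x) = sin(x) * x

# A backward pass written by hand for h, in the same (value, pullback) form
function h_by_hand(x)
    a = sin(x)
    y = a * x
    function h_back(dy)
        # y = a * x depends on x directly, and also through a = sin(x)
        dx = dy * a + (dy * x) * cos(x)
        return (dx,)
    end
    return y, h_back
end

y1, back1 = h_by_hand(0.7)
y2, back2 = pullback(h, 0.7)   # backward pass generated by Zygote

# both give d/dx (x sin x) = sin(x) + x cos(x) at x = 0.7
back1(1.0)[1], back2(1.0)[1]
```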
Zygote supports the flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more. Mutation and exception handling are currently not supported.
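For example, the function being differentiated can even be chosen at run time from user input (here the user types sin at the prompt):

julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);

julia> gradient(x -> fs[readline()](x), 1)
sin
(0.5403023058681398,)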
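A non-interactive sketch of the other features listed above: control flow, recursion, structs, and a dictionary lookup inside a closure. The helpers piecewise, pow, Point, sqnorm and pick are hypothetical illustrations, not part of Zygote. For a struct argument, the gradient is assumed to come back as a NamedTuple whose fields mirror the struct's fields.

```julia
using Zygote

# Control flow: the branch taken depends on the runtime value of x
piecewise(x) = x > 0 ? 2x : 0.5x

# Recursion: pow(x, n) computes x^n by repeated multiplication
pow(x, n) = n <= 0 ? one(x) : x * pow(x, n - 1)

# Structs: the gradient mirrors the struct's fields
struct Point
    x::Float64
    y::Float64
end
sqnorm(p::Point) = p.x^2 + p.y^2

# Dictionaries and closures: the function is looked up by key instead of read from stdin
fs = Dict("sin" => sin, "cos" => cos, "tan" => tan)
pick(name) = x -> fs[name](x)

gradient(piecewise, 3.0)           # → (2.0,)
gradient(piecewise, -3.0)          # → (0.5,)
gradient(x -> pow(x, 3), 2.0)      # 3x^2 at x = 2, i.e. 12.0
gradient(sqnorm, Point(1.0, 2.0))  # fields x = 2.0, y = 4.0
gradient(pick("cos"), 1.0)         # -sin(1.0)
```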
Zygote uses the differentiation rules (rrules) defined in the ChainRules.jl ruleset.
Custom gradients can be defined by extending ChainRulesCore.jl's rrule function:
julia> using ChainRulesCore

julia> add(a, b) = a + b

julia> function ChainRulesCore.rrule(::typeof(add), a, b)
           # pullback: NoTangent() for the function itself, then the tangents for a and b
           add_pb(dy) = (NoTangent(), dy, dy)
           return add(a, b), add_pb
       end
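A slightly fuller sketch of the same pattern, where the hand-written derivative replaces tracing through a branchy implementation. mysoftplus is a hypothetical example function (not part of Zygote or ChainRulesCore). Its rrule returns the primal value and a pullback whose first entry is NoTangent() for the function itself, as in the add example. Once the rrule is defined, gradient is expected to use it.

```julia
using Zygote, ChainRulesCore

# Hypothetical example: a numerically stable softplus, log(1 + exp(x))
mysoftplus(x) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

function ChainRulesCore.rrule(::typeof(mysoftplus), x)
    y = mysoftplus(x)
    # the derivative of softplus is the logistic sigmoid 1 / (1 + exp(-x))
    mysoftplus_pullback(dy) = (NoTangent(), dy / (1 + exp(-x)))
    return y, mysoftplus_pullback
end

gradient(mysoftplus, 0.0)   # → (0.5,)
```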
To support large machine learning models with many parameters, Zygote can differentiate a whole model with respect to its (possibly nested) structure of parameters, when the model is passed explicitly as an argument to the function being differentiated.
julia> using Zygote

julia> model = (W = rand(2, 3), b = rand(2));

julia> predict(model, x) = model.W * x .+ model.b;

julia> g = gradient(m -> sum(predict(m, [1, 2, 3])), model)[1]
(W = [1.0 2.0 3.0; 1.0 2.0 3.0], b = [1.0, 1.0])
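A sketch of how the explicit style extends to a nested model and a single gradient-descent step, assuming the gradient has the same nested NamedTuple shape as the model (as the flat example above suggests). The two-layer model, loss and the descend helper are hypothetical; in practice an optimiser library such as Optimisers.jl would normally perform the update. Deterministic initial values are used so the step is reproducible.

```julia
using Zygote

# A nested model: a NamedTuple of NamedTuples of arrays
model = (layer1 = (W = fill(0.1, 4, 3), b = zeros(4)),
         layer2 = (W = fill(0.1, 2, 4), b = zeros(2)))

predict(m, x) = m.layer2.W * tanh.(m.layer1.W * x .+ m.layer1.b) .+ m.layer2.b
loss(m, x, y) = sum(abs2, predict(m, x) .- y)

# descend: hypothetical helper (not Zygote API) applying p - η * g leaf by leaf,
# walking the model and its gradient in parallel
descend(p::AbstractArray, g::AbstractArray, η) = p .- η .* g
descend(p::NamedTuple, g::NamedTuple, η) = map((pl, gl) -> descend(pl, gl, η), p, g)

x, y = [1.0, 2.0, 3.0], [0.5, -0.5]
g = gradient(m -> loss(m, x, y), model)[1]   # same nested shape as model
new_model = descend(model, g, 0.001)         # one small gradient-descent step
```

Passing the model as an ordinary argument keeps the gradient a plain data structure, so the update is just a function over two matching trees.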
Warning
Zygote also has a legacy implicit-parameters interface, in which the parameters of
interest are collected in a Zygote.Params object and the gradients returned in a
dictionary-like Grads object. This interface is deprecated and will be removed in a
future release; use the explicit style shown above instead.

