I think you want either of those [1]s, but not both.
I’m surprised that Optimisers.update! does not object more forcefully to the resulting mismatch of nested structures, but it does complain briefly. In fact, I think that warning comes from a path meant to catch the case where you forgot the [1] entirely, and here it makes things worse: it keeps just the nothing? (There is a quick structure check below, after the Flux.state comparison.)
julia> t = 0:3.3:10;
julia> force_data = randn(Float32, size(t)); # fake data, without running ODE code
julia> position_data = randn(Float32, size(t));
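(NNForce, loss, and opt aren't defined in the quoted code; something along these lines matches the shapes printed below, but the hidden width, the exact loss, and the optimiser are my guesses, so treat this as a sketch.)

julia> using Flux

julia> NNForce = Chain(x -> [x],              # wrap the scalar input in a 1-element vector
                       Dense(1 => 32, tanh),  # hidden width is a guess
                       Dense(32 => 1),
                       first);                # back to a scalar output

julia> loss(m) = sum(abs2, m(position_data[i]) - force_data[i] for i in eachindex(position_data));

julia> opt = Flux.setup(Adam(0.01), NNForce);  # optimiser state, mirrors the model's structure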
julia> gradient(loss, NNForce)[1]
(layers = (nothing, (weight = Float32[-0.00031716842; 0.005882237; … ; 0.02040809; -0.017797563;;], bias = Float32[-0.00015526835, 0.0045354646, 0.0001758635, 0.000911925, -0.00093348324, -0.0028021857, -0.00037156977, -0.0002479069, 0.00634047, 0.0022809654 … 0.00086749066, -0.015178025, -0.0015061237, 0.005903747, 0.0004341714, 0.011407629, 0.005850576, 0.0060465317, 0.019239068, 0.0012870021], σ = nothing), (weight = Float32[-0.014764998 0.015171923 … -0.022916555 0.0054369], bias = Float32[-0.0074519515], σ = nothing), nothing),)
julia> Flux.state(NNForce) # model's structure, must match the gradient!
(layers = ((), (weight = Float32[-0.28203613; 0.30942923; … ; -0.40744054; 0.09964981;;], bias = Float32[0.0046208426, -0.032450862, 0.06144477, 0.024944201, 0.057604324, 0.048263744, -0.009535714, -0.05626719, -0.03320607, -0.02485606 … -0.0106274625, 0.07197578, -0.042995136, -0.017393522, -0.021419354, -0.046221554, -0.06806464, -0.016840763, -0.06592702, 0.049628224], σ = ()), (weight = Float32[-0.007292727 0.16620831 … 0.41103598 -0.30187216], bias = Float32[-0.16961181], σ = ()), ()),)
julia> gradient(loss, NNForce)[1][1] # wrong!
(nothing, (weight = Float32[-0.00031716842; 0.005882237; … ; 0.02040809; -0.017797563;;], bias = Float32[-0.00015526835, 0.0045354646, 0.0001758635, 0.000911925, -0.00093348324, -0.0028021857, -0.00037156977, -0.0002479069, 0.00634047, 0.0022809654 … 0.00086749066, -0.015178025, -0.0015061237, 0.005903747, 0.0004341714, 0.011407629, 0.005850576, 0.0060465317, 0.019239068, 0.0012870021], σ = nothing), (weight = Float32[-0.014764998 0.015171923 … -0.022916555 0.0054369], bias = Float32[-0.0074519515], σ = nothing), nothing)
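A quick check of the structures, using the setup above: the model and grads[1] are both NamedTuples with a layers field, while grads[1][1] is only the inner tuple of per-layer gradients, and its first entry is nothing because that layer has no trainable parameters.

julia> g = gradient(loss, NNForce);  # a 1-tuple, one entry per argument of loss

julia> g[1] isa NamedTuple           # same (layers = ...,) shape as Flux.state(NNForce)
true

julia> g[1][1] isa Tuple             # just the inner per-layer tuple
true

julia> g[1][1][1] === nothing        # the first layer has nothing to update
true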
julia> for i in 1:10
           ∂loss∂m = gradient(loss, NNForce)[1]
           Flux.Optimisers.update!(opt, NNForce, ∂loss∂m[1]) # as above, with [1][1]
           println("loss: ", loss(NNForce))
       end
┌ Warning: explicit `update!(opt, model, grad)` wants the gradient for the model alone,
│ not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.
└ @ Flux ~/.julia/packages/Flux/BkG8S/src/layers/basic.jl:87
loss: 2.9432225
┌ Warning: explicit `update!(opt, model, grad)` wants the gradient for the model alone,
│ not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.
└ @ Flux ~/.julia/packages/Flux/BkG8S/src/layers/basic.jl:87
loss: 2.9432225
...
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225 # loss has not decreased
julia> for i in 1:10
           ∂loss∂m = gradient(loss, NNForce)[1]
           Flux.Optimisers.update!(opt, NNForce, ∂loss∂m) # corrected
           println("loss: ", loss(NNForce))
       end
loss: 2.3963256
loss: 2.0603027
loss: 1.853425
loss: 1.7260386
loss: 1.6475848
loss: 1.599206
loss: 1.5692846
loss: 1.5506897
loss: 1.5390539
loss: 1.5317094 # now it learns?
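As a side note, not needed for the fix: the corrected loop runs the forward pass twice per iteration, once inside gradient and once to print the loss. Flux.withgradient returns both together; a small variant, same assumptions as above (note that val is the loss before that step's update, while the loop above prints it after):

for i in 1:10
    res = Flux.withgradient(loss, NNForce)              # res.val is the loss at the current parameters, res.grad the gradient tuple
    Flux.Optimisers.update!(opt, NNForce, res.grad[1])  # still exactly one [1]
    println("loss: ", res.val)
end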