Scalar operations during backpropagation of a neural ODE with DiffEqFlux

Stack Overflow user
Asked on 2019-12-27 18:45:14
1 answer · 219 views · 0 followers · 2 votes

I am implementing a neural ODE with the DiffEqFlux package in Julia, and I am having trouble getting it to work on the GPU.

The simplest example is as follows:

using DifferentialEquations
using Flux, DiffEqFlux
using CuArrays
CuArrays.allowscalar(false)

x = Float32[2.; 0.]|>gpu
tspan = Float32.((0.0f0,25.0f0))
dudt = Chain(Dense(2,50,tanh),Dense(50,2))|>gpu

loss() = sum(neural_ode(dudt,x,tspan,Tsit5(),save_everystep=false,save_start=false))

@show(loss())
Flux.back!(loss())

And the stack trace:

loss() = -24.529072f0 (tracked)
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
 [3] getindex at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54 [inlined]
 [4] getindex at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/adjtrans.jl:129 [inlined]
 [5] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
 [6] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
 [7] getindex at ./reshapedarray.jl:231 [inlined]
 [8] macro expansion at ./multidimensional.jl:671 [inlined]
 [9] macro expansion at ./cartesian.jl:64 [inlined]
 [10] macro expansion at ./multidimensional.jl:666 [inlined]
 [11] _unsafe_getindex! at ./multidimensional.jl:662 [inlined]
 [12] _unsafe_getindex(::IndexLinear, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./multidimensional.jl:656
 [13] _getindex at ./multidimensional.jl:642 [inlined]
 [14] getindex(::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./abstractarray.jl:927
 [15] (::getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}})(::TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:196
 [16] iterate at ./generator.jl:47 [inlined]
 [17] collect(::Base.Generator{Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}},getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}}}) at ./array.jl:606
 [18] #428 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:193 [inlined]
 [19] back_(::Tracker.Call{getfield(Tracker, Symbol("##428#431")){Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}}},Tuple{Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}},Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:35
 [20] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [21] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [22] foreach(::Function, ::Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}, ::Tuple{Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}},CuArray{Float32,1,Nothing},Nothing,Nothing}) at ./abstractarray.jl:1867
 [23] back_(::Tracker.Call{getfield(DiffEqFlux, Symbol("##25#28")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:save_everystep,),Tuple{Bool}}},TrackedArray{…,CuArray{Float32,1,Nothing}},CuArray{Float32,1,Nothing},Tuple{Tsit5},ODESolution{Float32,2,Array{CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArray{Float32,1,Nothing},Tuple{Float32,Float32},false,CuArray{Float32,1,Nothing},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [24] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [25] #13 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
 [26] foreach at ./abstractarray.jl:1867 [inlined]
 [27] back_(::Tracker.Call{getfield(Tracker, Symbol("##484#485")){TrackedArray{…,CuArray{Float32,1,Nothing}}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Float32, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [28] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [29] #back!#15 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:77 [inlined]
 [30] #back! at ./none:0 [inlined]
 [31] #back!#32 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:16 [inlined]
 [32] back!(::Tracker.TrackedReal{Float32}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:14
 [33] top-level scope at none:0
 [34] include at ./boot.jl:326 [inlined]
 [35] include_relative(::Module, ::String) at ./loading.jl:1038
 [36] include(::Module, ::String) at ./sysimg.jl:29
 [37] include(::String) at ./client.jl:403
 [38] top-level scope at none:0

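The forward pass works fine, but the backward call triggers a scalar getindex. If I allow scalar indexing there is of course no error, but it is much slower than running on the CPU (even for the larger problem I am actually working on).

I am not sure whether this is related to any package dependencies, but here is what I currently have installed.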

  [fbb218c0] BSON v0.2.4
  [c5f51814] CUDAdrv v5.0.1
  [be33ccc6] CUDAnative v2.7.0
  [3a865a2d] CuArrays v1.6.0
  [aae7a2af] DiffEqFlux v0.7.0
  [9fdde737] DiffEqOperators v4.6.1
  [0c46a032] DifferentialEquations v6.9.0
  [587475ba] Flux v0.8.3
  [a98d9a8b] Interpolations v0.12.5
  [15e1cf62] NPZ v0.4.0
  [91a5bcdd] Plots v0.28.4

Any ideas?

1 Answer

Stack Overflow user

Accepted answer

Answered on 2019-12-29 02:54:46

This issue is fixed in DiffEqFlux 0.10.1. Run ]up to upgrade, or run ]add DiffEqFlux@0.10.1 to request that version explicitly.
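
For scripts or non-interactive sessions, the same upgrade can be sketched with the Pkg API instead of the REPL's ] mode. This is a minimal sketch, assuming Julia's standard Pkg library and network access to the package registry. It is not part of the original answer.

```julia
using Pkg

# Option 1: update every package in the active environment
# (the scripted equivalent of `]up`).
Pkg.update()

# Option 2: request the fixed release explicitly
# (the scripted equivalent of `]add DiffEqFlux@0.10.1`).
Pkg.add(Pkg.PackageSpec(name="DiffEqFlux", version="0.10.1"))

# Print the installed versions to confirm that DiffEqFlux is now at 0.10.1 or later.
Pkg.status()
```

Only one of the two options is needed. If the resolver reports a conflict, other packages in the environment may need updating as well.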

Votes: 1
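The original page content is provided by Stack Overflow. Translation support is provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/59499585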
