diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl
index 2c3b895e..d7bfb500 100644
--- a/src/conditional_layers/conditional_layer_glow.jl
+++ b/src/conditional_layers/conditional_layer_glow.jl
@@ -127,14 +127,14 @@ function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL
     X_ = tensor_cat(X1, X2)
     X = L.C.inverse(X_)
 
-    save == true ? (return X, X1, X2, Sm) : (return X)
+    save == true ? (return X, X1, X2, logS, Sm) : (return X)
 end
 
 # Backward pass: Input (ΔY, Y), Output (ΔX, X)
 function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow;) where {T,N}
 
     # Recompute forward state
-    X, X1, X2, S = inverse(Y, C, L; save=true)
+    X, X1, X2, logS, S = inverse(Y, C, L; save=true)
 
     # Backpropagate residual
     ΔY1, ΔY2 = tensor_split(ΔY)
@@ -147,7 +147,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
     end
 
     # Backpropagate RB
-    ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
+    ΔX2_ΔC = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), (tensor_cat(X2, C)))
     ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
     ΔX2 += ΔY2
 
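Note: the change above is the crux of this patch. `L.activation.backward(ΔS, S)` evaluated the activation gradient from the activation output `S` alone, which is only valid for activations whose derivative can be recovered from the output. The new call `backward(ΔS, logS, S, L.activation)` also receives the pre-activation `logS`, which `inverse(...; save=true)` now returns, so gradient functions written in terms of the original input remain correct. An illustrative contrast (these two definitions are sketches, not the package's exact implementations):

    # Derivative recoverable from the output: σ'(x) = y .* (1 .- y) with y = σ(x)
    sigmoid_grad(Δy, y) = Δy .* y .* (1f0 .- y)
    # Derivative defined in terms of the input: needs x itself, not y = max.(x, 0)
    relu_grad(Δy, x) = Δy .* (x .> 0f0)
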
diff --git a/src/layers/invertible_layer_basic.jl b/src/layers/invertible_layer_basic.jl
index 469e86af..6a6355e5 100644
--- a/src/layers/invertible_layer_basic.jl
+++ b/src/layers/invertible_layer_basic.jl
@@ -96,9 +96,9 @@ function forward(X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLa
     Y2 = S.*X2 + logS_T2
 
     if logdet
-        save ? (return X1, Y2, coupling_logdet_forward(S), S) : (return X1, Y2, coupling_logdet_forward(S))
+        save ? (return X1, Y2, coupling_logdet_forward(S), logS_T1, S) : (return X1, Y2, coupling_logdet_forward(S))
     else
-        save ? (return X1, Y2, S) : (return X1, Y2)
+        save ? (return X1, Y2, logS_T1, S) : (return X1, Y2)
     end
 end
 
@@ -112,9 +112,9 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa
     X2 = (Y2 - logS_T2) ./ (S .+ eps(T)) # add epsilon to avoid division by 0
 
     if logdet
-        save == true ? (return Y1, X2, -coupling_logdet_forward(S), S) : (return Y1, X2, -coupling_logdet_forward(S))
+        save == true ? (return Y1, X2, -coupling_logdet_forward(S), logS_T1, S) : (return Y1, X2, -coupling_logdet_forward(S))
     else
-        save == true ? (return Y1, X2, S) : (return Y1, X2)
+        save == true ? (return Y1, X2, logS_T1, S) : (return Y1, X2)
     end
 end
 
@@ -122,7 +122,7 @@ end
 function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}
 
     # Recompute forward state
-    X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false)
+    X1, X2, logS_T1, S = inverse(Y1, Y2, L; save=true, logdet=false)
 
     # Backpropagate residual
     ΔT = copy(ΔY2)
@@ -132,11 +132,11 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst
     end
     ΔX2 = ΔY2 .* S
     if set_grad
-        ΔX1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1) + ΔY1
+        ΔX1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1) + ΔY1
     else
-        ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad)
+        ΔX1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1; set_grad=set_grad)
         if L.logdet
-            _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad)
+            _, ∇logdet = L.RB.backward(tensor_cat(backward(coupling_logdet_backward(S), logS_T1, S, L.activation), 0 .*ΔT), X1; set_grad=set_grad)
         end
         ΔX1 += ΔY1
     end
@@ -152,7 +152,7 @@ end
 function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}
 
     # Recompute inverse state
-    Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false)
+    Y1, Y2, logS_T1, S = forward(X1, X2, L; save=true, logdet=false)
 
     # Backpropagate residual
     ΔT = -ΔX2 ./ S
@@ -161,9 +161,9 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::
         set_grad ? (ΔS += coupling_logdet_backward(S)) : (∇logdet = -coupling_logdet_backward(S))
     end
     if set_grad
-        ΔY1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1) + ΔX1
+        ΔY1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1) + ΔX1
     else
-        ΔY1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1; set_grad=set_grad)
+        ΔY1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1; set_grad=set_grad)
         ΔY1 += ΔX1
     end
     ΔY2 = - ΔT
@@ -187,14 +187,14 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab
     logS_T1, logS_T2 = tensor_split(L.RB.forward(X1))
     ΔlogS_T1, ΔlogS_T2 = tensor_split(jacobian(ΔX1, Δθ, X1, L.RB)[1])
     S = L.activation.forward(logS_T1)
-    ΔS = L.activation.backward(ΔlogS_T1, S)
+    ΔS = backward(ΔlogS_T1, logS_T1, S, L.activation)
     Y2 = S.*X2 + logS_T2
     ΔY2 = ΔS.*X2 + S.*ΔX2 + ΔlogS_T2
 
     if logdet
         # Gauss-Newton approximation of logdet terms
         JΔθ = tensor_split(L.RB.jacobian(zeros(Float32, size(ΔX1)), Δθ, X1)[1])[1]
-        GNΔθ = -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, S), zeros(Float32, size(S))), X1)[2]
+        GNΔθ = -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS_T1, S, L.activation), zeros(Float32, size(S))), X1)[2]
         save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ)
     else
diff --git a/src/layers/invertible_layer_glow.jl b/src/layers/invertible_layer_glow.jl
index 7a2c70fd..57d206e1 100644
--- a/src/layers/invertible_layer_glow.jl
+++ b/src/layers/invertible_layer_glow.jl
@@ -129,14 +129,14 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerGlow; save=false) where
     X_ = tensor_cat(X1, X2)
     X = L.C.inverse(X_)
 
-    save == true ? (return X, X1, X2, Sm) : (return X)
+    save == true ? (return X, X1, X2, logSm, Sm) : (return X)
 end
 
 # Backward pass: Input (ΔY, Y), Output (ΔX, X)
 function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerGlow; set_grad::Bool=true) where {T,N}
 
     # Recompute forward state
-    X, X1, X2, S = inverse(Y, L; save=true)
+    X, X1, X2, logS, S = inverse(Y, L; save=true)
 
     # Backpropagate residual
     ΔY1, ΔY2 = tensor_split(ΔY)
@@ -148,10 +148,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
     ΔX1 = ΔY1 .* S
 
     if set_grad
-        ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
+        ΔX2 = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), X2) + ΔY2
     else
-        ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
-        _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
+        ΔX2, Δθrb = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT; ), X2; set_grad=set_grad)
+        _, ∇logdet = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), 0f0.*ΔT;), X2; set_grad=set_grad)
         ΔX2 += ΔY2
     end
     ΔX_ = tensor_cat(ΔX1, ΔX2)
@@ -187,7 +187,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
     ΔlogS, ΔlogT = tensor_split(ΔlogS_T)
     logS, logT = tensor_split(logS_T)
     Sm = L.activation.forward(logS)
-    ΔS = L.activation.backward(ΔlogS, nothing;x=logS)
+    ΔS = backward(ΔlogS, logS, Sm, L.activation)
     Tm = logT
     ΔT = ΔlogT
     Y1 = Sm.*X1 + Tm
@@ -197,7 +197,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
 
     # Gauss-Newton approximation of logdet terms
     JΔθ,_ = tensor_split(L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1])#[:, :, 1:k, :]
-    GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)
+    GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS, Sm, L.activation), zeros(Float32, size(Sm))), X2)[2]; dims=1)
 
     L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
 end
diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl
index 84020cbb..155f1a49 100644
--- a/src/layers/invertible_layer_irim.jl
+++ b/src/layers/invertible_layer_irim.jl
@@ -65,12 +65,12 @@ end
 @Flux.functor CouplingLayerIRIM
 
 # 2D Constructor from input dimensions
-function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
-    k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2)
+function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
+    k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2, rb_activation=ReLUlayer())
 
     # 1x1 Convolution and residual block for invertible layer
     C = Conv1x1(n_in)
-    RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+    RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims, activation=rb_activation)
 
     return CouplingLayerIRIM(C, RB)
 end
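Note: `CouplingLayerIRIM` now forwards an `rb_activation` keyword to its residual block, defaulting to `ReLUlayer()` as before. A usage sketch (the channel and hidden sizes are placeholders):

    using InvertibleNetworks
    L_default = CouplingLayerIRIM(4, 8)                                # ReLU residual block
    L_sigmoid = CouplingLayerIRIM(4, 8; rb_activation=SigmoidLayer())  # invertible activation
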
diff --git a/src/layers/layer_residual_block.jl b/src/layers/layer_residual_block.jl
index 60fb58f3..3cdbba77 100644
--- a/src/layers/layer_residual_block.jl
+++ b/src/layers/layer_residual_block.jl
@@ -127,10 +127,10 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {
     cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
 
-    # Return if only recomputing state
-    save && (return Y1, Y2, Y3)
-
-    # Finish forward
-    RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
+
+    X4 = RB.fan == true ? RB.activation.forward(Y3) : GaLU(Y3)
+    save && (return Y1, Y2, Y3, X2, X3, X4)
+    return X4
 end
 
 # Backward
@@ -140,25 +140,25 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},
     dims = collect(1:N-1); dims[end] +=1
 
     # Recompute forward states from input X
-    Y1, Y2, Y3 = forward(X1, RB; save=true)
+    Y1, Y2, Y3, X2, X3, X4 = forward(X1, RB; save=true)
 
     # Cdims
     cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
     cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
 
     # Backpropagate residual ΔX4 and compute gradients
-    RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
+    RB.fan == true ? (ΔY3 = backward(ΔX4, Y3, X4, RB.activation)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
     ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
     ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)
 
-    ΔY2 = RB.activation.backward(ΔX3, Y2)
+    ΔY2 = backward(ΔX3, Y2, X3, RB.activation)
     ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
     ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
     Δb2 = sum(ΔY2, dims=dims)[inds...]
 
     cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])
-    ΔY1 = RB.activation.backward(ΔX2, Y1)
+    ΔY1 = backward(ΔX2, Y1, X2, RB.activation)
     ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
     ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
     Δb1 = sum(ΔY1, dims=dims)[inds...]
@@ -187,21 +187,21 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},
     Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
     ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
     X2 = RB.activation.forward(Y1)
-    ΔX2 = RB.activation.backward(ΔY1, Y1)
+    ΔX2 = backward(ΔY1, Y1, X2, RB.activation)
 
     cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
     Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
     ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
     X3 = RB.activation.forward(Y2)
-    ΔX3 = RB.activation.backward(ΔY2, Y2)
+    ΔX3 = backward(ΔY2, Y2, X3, RB.activation)
 
     cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
     ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
     if RB.fan == true
         X4 = RB.activation.forward(Y3)
-        ΔX4 = RB.activation.backward(ΔY3, Y3)
+        ΔX4 = backward(ΔY3, Y3, X4, RB.activation)
     else
         ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
     end
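Note: with `save=true`, the residual block's forward pass now also returns the post-activation states `X2`, `X3`, `X4` next to the pre-activation states `Y1`, `Y2`, `Y3`, so `backward` can hand both to the activation gradient. A minimal sketch of the call pattern, assuming `fan=true` so the final nonlinearity is `RB.activation` rather than `GaLU` (the shapes and the `fan` keyword usage here are assumptions, not taken from this patch):

    using InvertibleNetworks
    RB = ResidualBlock(2, 8; fan=true, ndims=2)
    X1 = randn(Float32, 16, 16, 2, 4)
    Y1, Y2, Y3, X2, X3, X4 = forward(X1, RB; save=true)  # pre- and post-activations
    ΔX4 = randn(Float32, size(X4)...)
    ΔY3 = backward(ΔX4, Y3, X4, RB.activation)           # gradient sees input and output
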
diff --git a/src/utils/activation_functions.jl b/src/utils/activation_functions.jl
index 8695f23c..b5a17fc8 100644
--- a/src/utils/activation_functions.jl
+++ b/src/utils/activation_functions.jl
@@ -19,6 +19,18 @@ struct ActivationFunction
     backward::Function
 end
 
+function backward(Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}, activation::ActivationFunction) where {T, N}
+    backward_activation(activation.backward, activation.inverse, Δy, x, y)
+end
+
+function backward_activation(back::Function, inverse::Nothing, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+    back(Δy, x)
+end
+
+function backward_activation(back::Function, inverse::Function, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+    back(Δy, y)
+end
+
 function ReLUlayer()
     return ActivationFunction(ReLU, nothing, ReLUgrad)
 end
@@ -46,7 +58,7 @@ function GaLUlayer()
 end
 
 function ExpClampLayer()
-    return ActivationFunction(x -> ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
+    return ActivationFunction(x -> 2 * ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
 end
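Note: the new `backward(Δy, x, y, activation)` entry point dispatches on `activation.inverse`. Layers constructed without an inverse (`nothing`, e.g. `ReLUlayer()`) store gradient functions written in terms of the original input, so they receive `x`; layers with an inverse (e.g. `SigmoidLayer()`) store gradients evaluated from the output, so they receive `y`. A rough usage sketch with placeholder shapes:

    using InvertibleNetworks
    x  = randn(Float32, 16, 16, 4, 2)
    Δy = randn(Float32, size(x)...)

    act = SigmoidLayer()                                  # has an inverse
    Δx = backward(Δy, x, act.forward(x), act)             # dispatches to back(Δy, y)

    act_relu = ReLUlayer()                                # inverse === nothing
    Δx2 = backward(Δy, x, act_relu.forward(x), act_relu)  # dispatches to back(Δy, x)

The `ExpClampLayer` change folds the factor of two into the stored forward function (`x -> 2 * ExpClamp(x)`), making the stored forward, inverse, and gradient closures mutually consistent.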
diff --git a/test/test_layers/test_coupling_layer_irim.jl b/test/test_layers/test_coupling_layer_irim.jl
index b2522a3a..f7815d2e 100644
--- a/test/test_layers/test_coupling_layer_irim.jl
+++ b/test/test_layers/test_coupling_layer_irim.jl
@@ -20,9 +20,9 @@ X0 = randn(Float32, nx, ny, n_in, batchsize)
 dX = X - X0
 
 # Invertible layers
-L = CouplingLayerIRIM(n_in, n_hidden)
-L01 = CouplingLayerIRIM(n_in, n_hidden)
-L02 = CouplingLayerIRIM(n_in, n_hidden)
+L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+L01 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+L02 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
 
 ###################################################################################################
 # Test invertibility
@@ -131,9 +131,9 @@ end
 # Gradient test
 
 # Initialization
-L = CouplingLayerIRIM(n_in, n_hidden)
+L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
 θ = deepcopy(get_params(L))
-L0 = CouplingLayerIRIM(n_in, n_hidden)
+L0 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
 θ0 = deepcopy(get_params(L0))
 
 X = randn(Float32, nx, ny, n_in, batchsize)
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index 946b39b0..f300435e 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -24,7 +24,7 @@ N = (nx,ny)
 # Invertibility
 
 # Network and input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 X = rand(Float32, N..., n_in, batchsize) |> device
 Cond = rand(Float32, N..., n_cond, batchsize) |> device
 
@@ -67,7 +67,7 @@ N = (nx,ny)
 # Invertibility
 
 # Network and input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 X = rand(Float32, N..., n_in, batchsize) |> device
 Cond = rand(Float32, N..., n_cond, batchsize) |> device
 
@@ -106,7 +106,7 @@ end
 # Gradient test w.r.t. input
 
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 X = rand(Float32, N..., n_in, batchsize) |> device
 Cond = rand(Float32, N..., n_cond, batchsize) |> device
 X0 = rand(Float32, N..., n_in, batchsize) |> device
@@ -135,8 +135,8 @@ end
 # Gradient test w.r.t. parameters
 
 X = rand(Float32, N..., n_in, batchsize) |> device
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
-G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
+G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 Gini = deepcopy(G0)
 
 # Test one parameter from residual block and 1x1 conv
@@ -171,7 +171,7 @@ end
 sum_net = ResNet(n_cond, 16, 3; norm=nothing) # make sure it doesnt have any weird normalizations
 
 # Network and input
-flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N))
+flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer())
 G = SummarizedNet(flow, sum_net) |> device
 
 X = rand(Float32, N..., n_in, batchsize) |> device;
@@ -238,7 +238,7 @@ end
 # Gradient test w.r.t. parameters
 
 X = rand(Float32, N..., n_in, batchsize) |> device
-flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 G0 = SummarizedNet(flow0, sum_net) |> device
 Gini = deepcopy(G0)
 
@@ -273,7 +273,7 @@ N = (nx,ny,nz)
 # Invertibility
 
 # Network and input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 X = rand(Float32, N..., n_in, batchsize) |> device
 Cond = rand(Float32, N..., n_cond, batchsize) |> device
 
@@ -304,7 +304,7 @@ end
 # Gradient test w.r.t. input
 
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 X = rand(Float32, N..., n_in, batchsize) |> device
 Cond = rand(Float32, N..., n_cond, batchsize) |> device
 X0 = rand(Float32, N..., n_in, batchsize) |> device
@@ -333,8 +333,8 @@ end
 # Gradient test w.r.t. parameters
 
 X = rand(Float32, N..., n_in, batchsize) |> device
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
-G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
+G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 Gini = deepcopy(G0)
 
 # Test one parameter from residual block and 1x1 conv
@@ -368,7 +368,7 @@ end
 sum_net_3d = ResNet(n_cond, 16, 3; ndims=3, norm=nothing) |> device# make sure it doesnt have any weird normalizati8ons
 
 # Network and input
-flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device;
+flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device;
 G = SummarizedNet(flow, sum_net_3d) |> device
 
 X = rand(Float32, N..., n_in, batchsize) |> device;
@@ -428,7 +428,7 @@ end
 # Gradient test w.r.t. parameters
 
 X = rand(Float32, N..., n_in, batchsize) |> device
-flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N), rb_activation=LeakyReLUlayer()) |> device
 G0 = SummarizedNet(flow0, sum_net_3d) |> device
 Gini = deepcopy(G0)
diff --git a/test/test_networks/test_multiscale_conditional_hint_network.jl b/test/test_networks/test_multiscale_conditional_hint_network.jl
index 260906ee..d94536aa 100644
--- a/test/test_networks/test_multiscale_conditional_hint_network.jl
+++ b/test/test_networks/test_multiscale_conditional_hint_network.jl
@@ -76,7 +76,7 @@ function grad_test_X(nx, ny, n_channel, batchsize, logdet, squeeze_type, split_s
     f0, gX, gY = loss(CH, X0, Y0)[1:3]
 
     maxiter = 5
-    h = 0.1f0
+    h = 5f-2
 
     err1 = zeros(Float32, maxiter)
     err2 = zeros(Float32, maxiter)
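Note on the step size: these tests use a first-order Taylor check, where the error of J(X0 + h*dX) against f0 should decay like O(h) and the error against f0 + h*<dX, gX> like O(h^2) as h is halved; starting from 5f-2 rather than 0.1f0 keeps the expansion closer to its asymptotic regime at Float32 precision. A generic sketch of such a test (not the package's `grad_test_X`; here `J` returns the scalar loss and `g` its gradient at a point):

    # Finite-difference gradient test: compare J(X0 + h*dX) against its Taylor expansion.
    function grad_test(J, g, X0, dX; maxiter=5, h=5f-2)
        f0, g0 = J(X0), g(X0)
        err1 = zeros(Float32, maxiter)   # zeroth-order error, expect O(h)
        err2 = zeros(Float32, maxiter)   # first-order error, expect O(h^2)
        for j = 1:maxiter
            f = J(X0 + h*dX)
            err1[j] = abs(f - f0)
            err2[j] = abs(f - f0 - h*sum(dX .* g0))
            h /= 2f0
        end
        return err1, err2
    end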