6 changes: 3 additions & 3 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -127,14 +127,14 @@ function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL
X_ = tensor_cat(X1, X2)
X = L.C.inverse(X_)

- save == true ? (return X, X1, X2, Sm) : (return X)
+ save == true ? (return X, X1, X2, logS, Sm) : (return X)
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow;) where {T,N}

# Recompute forward state
- X, X1, X2, S = inverse(Y, C, L; save=true)
+ X, X1, X2, logS, S = inverse(Y, C, L; save=true)

# Backpropagate residual
ΔY1, ΔY2 = tensor_split(ΔY)
@@ -147,7 +147,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
end

# Backpropagate RB
- ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
+ ΔX2_ΔC = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), (tensor_cat(X2, C)))
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
ΔX2 += ΔY2

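Context for the two changes above: `inverse(...; save=true)` now also returns the pre-activation log-scale `logS`, and the residual backpropagation calls a new activation-gradient entry point `backward(Δy, x, y, activation)` instead of `activation.backward(Δy, y)`. The extra argument lets the gradient be evaluated at the saved pre-activation when the activation has no inverse (see the new dispatch methods in src/utils/activation_functions.jl further down). A minimal usage sketch, assuming the package is loaded and qualifying `backward` in case it is not exported; array sizes are hypothetical:

    using InvertibleNetworks
    act   = SigmoidLayer()                # invertible activation: grad can use the output
    logS  = randn(Float32, 16, 16, 2, 4)  # pre-activation log-scale
    S     = act.forward(logS)             # activated scale
    ΔS    = randn(Float32, size(S)...)    # upstream sensitivity w.r.t. S
    ΔlogS = InvertibleNetworks.backward(ΔS, logS, S, act)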
26 changes: 13 additions & 13 deletions src/layers/invertible_layer_basic.jl
@@ -96,9 +96,9 @@ function forward(X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLa
Y2 = S.*X2 + logS_T2

if logdet
- save ? (return X1, Y2, coupling_logdet_forward(S), S) : (return X1, Y2, coupling_logdet_forward(S))
+ save ? (return X1, Y2, coupling_logdet_forward(S), logS_T1, S) : (return X1, Y2, coupling_logdet_forward(S))
else
- save ? (return X1, Y2, S) : (return X1, Y2)
+ save ? (return X1, Y2, logS_T1, S) : (return X1, Y2)
end
end

@@ -112,17 +112,17 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa
X2 = (Y2 - logS_T2) ./ (S .+ eps(T)) # add epsilon to avoid division by 0

if logdet
- save == true ? (return Y1, X2, -coupling_logdet_forward(S), S) : (return Y1, X2, -coupling_logdet_forward(S))
+ save == true ? (return Y1, X2, -coupling_logdet_forward(S), logS_T1, S) : (return Y1, X2, -coupling_logdet_forward(S))
else
- save == true ? (return Y1, X2, S) : (return Y1, X2)
+ save == true ? (return Y1, X2, logS_T1, S) : (return Y1, X2)
end
end

# 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

# Recompute forward state
- X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false)
+ X1, X2, logS_T1, S = inverse(Y1, Y2, L; save=true, logdet=false)

# Backpropagate residual
ΔT = copy(ΔY2)
@@ -132,11 +132,11 @@ end
end
ΔX2 = ΔY2 .* S
if set_grad
- ΔX1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1) + ΔY1
+ ΔX1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1) + ΔY1
else
- ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad)
+ ΔX1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1; set_grad=set_grad)
if L.logdet
- _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad)
+ _, ∇logdet = L.RB.backward(tensor_cat(backward(coupling_logdet_backward(S), logS_T1, S, L.activation), 0 .*ΔT), X1; set_grad=set_grad)
end
ΔX1 += ΔY1
end
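For reference, the residual step above follows from the affine coupling Y2 = S .* X2 + T: the sensitivities are ΔT = ΔY2, ΔS = ΔY2 .* X2 and ΔX2 = ΔY2 .* S, and only the ΔS term needs the new activation dispatcher. A self-contained finite-difference check of the ΔX2 rule (stand-in arrays, not package code):

    # Verify ∂/∂X2 of sum(S .* X2 + T) matches the analytic ΔX2 = ΔY2 .* S
    S, T = rand(Float32, 4) .+ 1f0, randn(Float32, 4)
    X2   = randn(Float32, 4)
    f(x) = sum(S .* x + T)            # scalar loss through the coupling
    ΔY2  = ones(Float32, 4)           # gradient of sum() w.r.t. Y2
    ΔX2  = ΔY2 .* S                   # rule used in the backward pass above
    h = 1f-3; e = zeros(Float32, 4); e[1] = h
    @assert isapprox((f(X2 + e) - f(X2 - e)) / 2h, ΔX2[1]; rtol=1f-2)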
@@ -152,7 +152,7 @@ end
function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

# Recompute inverse state
- Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false)
+ Y1, Y2, logS_T1, S = forward(X1, X2, L; save=true, logdet=false)

# Backpropagate residual
ΔT = -ΔX2 ./ S
@@ -161,9 +161,9 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::
set_grad ? (ΔS += coupling_logdet_backward(S)) : (∇logdet = -coupling_logdet_backward(S))
end
if set_grad
- ΔY1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1) + ΔX1
+ ΔY1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1) + ΔX1
else
- ΔY1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1; set_grad=set_grad)
+ ΔY1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1; set_grad=set_grad)
ΔY1 += ΔX1
end
ΔY2 = - ΔT
@@ -187,14 +187,14 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab
logS_T1, logS_T2 = tensor_split(L.RB.forward(X1))
ΔlogS_T1, ΔlogS_T2 = tensor_split(jacobian(ΔX1, Δθ, X1, L.RB)[1])
S = L.activation.forward(logS_T1)
- ΔS = L.activation.backward(ΔlogS_T1, S)
+ ΔS = backward(ΔlogS_T1, logS_T1, S, L.activation)
Y2 = S.*X2 + logS_T2
ΔY2 = ΔS.*X2 + S.*ΔX2 + ΔlogS_T2

if logdet
# Gauss-Newton approximation of logdet terms
JΔθ = tensor_split(L.RB.jacobian(zeros(Float32, size(ΔX1)), Δθ, X1)[1])[1]
- GNΔθ = -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, S), zeros(Float32, size(S))), X1)[2]
+ GNΔθ = -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS_T1, S, L.activation), zeros(Float32, size(S))), X1)[2]

save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ)
else
14 changes: 7 additions & 7 deletions src/layers/invertible_layer_glow.jl
@@ -129,14 +129,14 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerGlow; save=false) where
X_ = tensor_cat(X1, X2)
X = L.C.inverse(X_)

- save == true ? (return X, X1, X2, Sm) : (return X)
+ save == true ? (return X, X1, X2, logSm, Sm) : (return X)
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerGlow; set_grad::Bool=true) where {T,N}

# Recompute forward state
- X, X1, X2, S = inverse(Y, L; save=true)
+ X, X1, X2, logS, S = inverse(Y, L; save=true)

# Backpropagate residual
ΔY1, ΔY2 = tensor_split(ΔY)
@@ -148,10 +148,10 @@

ΔX1 = ΔY1 .* S
if set_grad
- ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
+ ΔX2 = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), X2) + ΔY2
else
- ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
- _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
+ ΔX2, Δθrb = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT; ), X2; set_grad=set_grad)
+ _, ∇logdet = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), 0f0.*ΔT;), X2; set_grad=set_grad)
ΔX2 += ΔY2
end
ΔX_ = tensor_cat(ΔX1, ΔX2)
@@ -187,7 +187,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
ΔlogS, ΔlogT = tensor_split(ΔlogS_T)
logS, logT = tensor_split(logS_T)
Sm = L.activation.forward(logS)
- ΔS = L.activation.backward(ΔlogS, nothing;x=logS)
+ ΔS = backward(ΔlogS, logS, Sm, L.activation)
Tm = logT
ΔT = ΔlogT
Y1 = Sm.*X1 + Tm
Expand All @@ -197,7 +197,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou

# Gauss-Newton approximation of logdet terms
JΔθ,_ = tensor_split(L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1])#[:, :, 1:k, :]
- GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)
+ GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS, Sm, L.activation), zeros(Float32, size(Sm))), X2)[2]; dims=1)

L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
end
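Note that the forward-mode path above now uses the same dispatcher as the reverse-mode code, replacing the old keyword special case `L.activation.backward(ΔlogS, nothing; x=logS)`. That reuse is sound because, for an elementwise activation, the Jacobian-vector product and the adjoint are both multiplication by the same diagonal derivative. A self-contained illustration with a sigmoid stand-in (not the package's implementation):

    σ(x)        = 1f0 ./ (1f0 .+ exp.(-x))
    σgrad(Δ, y) = Δ .* y .* (1f0 .- y)   # elementwise derivative, written via the output y
    x  = randn(Float32, 4); y = σ(x)
    Δx = randn(Float32, 4); Δy = randn(Float32, 4)
    jvp = σgrad(Δx, y)   # jacobian(): tangent of y from a tangent of x
    vjp = σgrad(Δy, y)   # backward(): cotangent of x from a cotangent of y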
6 changes: 3 additions & 3 deletions src/layers/invertible_layer_irim.jl
@@ -65,12 +65,12 @@ end
@Flux.functor CouplingLayerIRIM

# 2D Constructor from input dimensions
- function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
- k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2)
+ function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
+ k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2, rb_activation=ReLUlayer())

# 1x1 Convolution and residual block for invertible layer
C = Conv1x1(n_in)
- RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+ RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims, activation=rb_activation)

return CouplingLayerIRIM(C, RB)
end
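The new `rb_activation` keyword makes the residual-block activation configurable while the `ReLUlayer()` default preserves the previous behaviour. A usage sketch (channel and hidden sizes are hypothetical; the updated tests below use exactly this form):

    using InvertibleNetworks
    L_default = CouplingLayerIRIM(4, 8)                               # ReLU residual block, as before
    L_sigmoid = CouplingLayerIRIM(4, 8; rb_activation=SigmoidLayer()) # invertible activation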
22 changes: 11 additions & 11 deletions src/layers/layer_residual_block.jl
@@ -127,10 +127,10 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {

cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
- # Return if only recomputing state
- save && (return Y1, Y2, Y3)
- # Finish forward
- RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
+
+ X4 = RB.fan == true ? RB.activation.forward(Y3) : GaLU(Y3)
+ save && (return Y1, Y2, Y3, X2, X3, X4)
+ return X4
end
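With this restructuring, `save=true` returns the pre-activation states `Y1, Y2, Y3` together with the post-activation states `X2, X3, X4`, so the backward pass can hand each activation gradient its matching input/output pair. A sketch of the new contract as used inside `backward` below (variables as in that function, so this fragment is illustrative rather than standalone):

    # Recompute every state the activation gradients may need:
    Y1, Y2, Y3, X2, X3, X4 = forward(X1, RB; save=true)
    # Each gradient call now receives the (pre-activation, output) pair:
    ΔY3 = backward(ΔX4, Y3, X4, RB.activation)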

# Backward
@@ -140,25 +140,25 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},
dims = collect(1:N-1); dims[end] +=1

# Recompute forward states from input X
- Y1, Y2, Y3 = forward(X1, RB; save=true)
+ Y1, Y2, Y3, X2, X3, X4 = forward(X1, RB; save=true)

# Cdims
cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])

# Backpropagate residual ΔX4 and compute gradients
- RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
+ RB.fan == true ? (ΔY3 = backward(ΔX4, Y3, X4, RB.activation)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)

- ΔY2 = RB.activation.backward(ΔX3, Y2)
+ ΔY2 = backward(ΔX3, Y2, X3, RB.activation)
ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
Δb2 = sum(ΔY2, dims=dims)[inds...]

cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

- ΔY1 = RB.activation.backward(ΔX2, Y1)
+ ΔY1 = backward(ΔX2, Y1, X2, RB.activation)
ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
Δb1 = sum(ΔY1, dims=dims)[inds...]
@@ -187,21 +187,21 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},
Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
X2 = RB.activation.forward(Y1)
- ΔX2 = RB.activation.backward(ΔY1, Y1)
+ ΔX2 = backward(ΔY1, Y1, X2, RB.activation)

cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
X3 = RB.activation.forward(Y2)
- ΔX3 = RB.activation.backward(ΔY2, Y2)
+ ΔX3 = backward(ΔY2, Y2, X3, RB.activation)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
if RB.fan == true
X4 = RB.activation.forward(Y3)
- ΔX4 = RB.activation.backward(ΔY3, Y3)
+ ΔX4 = backward(ΔY3, Y3, X4, RB.activation)
else
ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
end
14 changes: 13 additions & 1 deletion src/utils/activation_functions.jl
@@ -19,6 +19,18 @@ struct ActivationFunction
backward::Function
end

+ function backward(Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}, activation::ActivationFunction) where {T, N}
+ backward_activation(activation.backward, activation.inverse, Δy, x, y)
+ end
+
+ function backward_activation(back::Function, inverse::Nothing, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+ back(Δy, x)
+ end
+
+ function backward_activation(back::Function, inverse::Function, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+ back(Δy, y)
+ end
+
function ReLUlayer()
return ActivationFunction(ReLU, nothing, ReLUgrad)
end
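The three `backward`/`backward_activation` methods added above implement a small dispatch on `activation.inverse`: when it is `nothing` (as for `ReLUlayer` here) the gradient is evaluated at the saved pre-activation `x`, and when an inverse exists the gradient can be evaluated from the output `y` alone. A self-contained sketch of the same idea (hypothetical names, not the package's types):

    struct Act{I}
        forward::Function
        inverse::I          # `nothing` if no closed-form inverse
        backward::Function
    end

    # ReLU-style: the gradient formula needs the pre-activation input x
    relu = Act(x -> max.(x, 0f0), nothing, (Δy, x) -> Δy .* (x .> 0f0))
    # Sigmoid-style: the gradient can be written purely in terms of the output y
    sigm = Act(x -> 1f0 ./ (1f0 .+ exp.(-x)),
               y -> log.(y ./ (1f0 .- y)),
               (Δy, y) -> Δy .* y .* (1f0 .- y))

    back(Δy, x, y, a::Act{Nothing})    = a.backward(Δy, x)  # use saved input
    back(Δy, x, y, a::Act{<:Function}) = a.backward(Δy, y)  # use output only

    x = randn(Float32, 5); Δy = ones(Float32, 5)
    back(Δy, x, relu.forward(x), relu)   # masked where x ≤ 0
    back(Δy, x, sigm.forward(x), sigm)   # σ'(x) = σ(x)(1 - σ(x))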
@@ -46,7 +58,7 @@ function GaLUlayer()
end

function ExpClampLayer()
- return ActivationFunction(x -> ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
+ return ActivationFunction(x -> 2 * ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
end
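The added factor of two makes the three `ExpClampLayer` closures mutually consistent: with `forward` returning `2 * ExpClamp(x)`, the unchanged `inverse = y -> ExpClampInv(y/2f0)` is a true round-trip partner, and the `Δy*2f0` in the gradient closure accounts for the same scaling. A round-trip check with stand-in definitions (the package's `ExpClamp` also clamps the exponent; the bounds used here are assumptions):

    expclamp(x)    = exp.(clamp.(x, -5f0, 5f0))
    expclampinv(y) = log.(y)                 # inverse away from the clamp
    fwd(x)  = 2f0 .* expclamp(x)             # new forward: x -> 2 * ExpClamp(x)
    finv(y) = expclampinv(y ./ 2f0)          # unchanged inverse: y -> ExpClampInv(y/2f0)
    x = Float32[-1.2, 0.0, 0.7, 2.3]         # values inside the clamp range
    @assert finv(fwd(x)) ≈ x                 # round-trip holds only with the 2×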


10 changes: 5 additions & 5 deletions test/test_layers/test_coupling_layer_irim.jl
@@ -20,9 +20,9 @@ X0 = randn(Float32, nx, ny, n_in, batchsize)
dX = X - X0

# Invertible layers
- L = CouplingLayerIRIM(n_in, n_hidden)
- L01 = CouplingLayerIRIM(n_in, n_hidden)
- L02 = CouplingLayerIRIM(n_in, n_hidden)
+ L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+ L01 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+ L02 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())

###################################################################################################
# Test invertibility
@@ -131,9 +131,9 @@ end
# Gradient test

# Initialization
- L = CouplingLayerIRIM(n_in, n_hidden)
+ L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
θ = deepcopy(get_params(L))
- L0 = CouplingLayerIRIM(n_in, n_hidden)
+ L0 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
θ0 = deepcopy(get_params(L0))
X = randn(Float32, nx, ny, n_in, batchsize)
