
Commit f4b423e

Added a backward implementation for ActivationFunction so that the user doesn't need to know whether the function is invertible. Fixed the ExpClampLayer activation.
1 parent daa2ea9 commit f4b423e

9 files changed: 69 additions & 57 deletions

src/conditional_layers/conditional_layer_glow.jl

Lines changed: 3 additions & 3 deletions
@@ -127,14 +127,14 @@ function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL
     X_ = tensor_cat(X1, X2)
     X = L.C.inverse(X_)

-    save == true ? (return X, X1, X2, Sm) : (return X)
+    save == true ? (return X, X1, X2, logS, Sm) : (return X)
 end

 # Backward pass: Input (ΔY, Y), Output (ΔX, X)
 function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow;) where {T,N}

     # Recompute forward state
-    X, X1, X2, S = inverse(Y, C, L; save=true)
+    X, X1, X2, logS, S = inverse(Y, C, L; save=true)

     # Backpropagate residual
     ΔY1, ΔY2 = tensor_split(ΔY)
@@ -147,7 +147,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
     end

     # Backpropagate RB
-    ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
+    ΔX2_ΔC = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), (tensor_cat(X2, C)))
     ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
     ΔX2 += ΔY2

src/layers/invertible_layer_basic.jl

Lines changed: 13 additions & 13 deletions
@@ -96,9 +96,9 @@ function forward(X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLa
     Y2 = S.*X2 + logS_T2

     if logdet
-        save ? (return X1, Y2, coupling_logdet_forward(S), S) : (return X1, Y2, coupling_logdet_forward(S))
+        save ? (return X1, Y2, coupling_logdet_forward(S), logS_T1, S) : (return X1, Y2, coupling_logdet_forward(S))
     else
-        save ? (return X1, Y2, S) : (return X1, Y2)
+        save ? (return X1, Y2, logS_T1, S) : (return X1, Y2)
     end
 end

@@ -112,17 +112,17 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa
     X2 = (Y2 - logS_T2) ./ (S .+ eps(T)) # add epsilon to avoid division by 0

     if logdet
-        save == true ? (return Y1, X2, -coupling_logdet_forward(S), S) : (return Y1, X2, -coupling_logdet_forward(S))
+        save == true ? (return Y1, X2, -coupling_logdet_forward(S), logS_T1, S) : (return Y1, X2, -coupling_logdet_forward(S))
     else
-        save == true ? (return Y1, X2, S) : (return Y1, X2)
+        save == true ? (return Y1, X2, logS_T1, S) : (return Y1, X2)
     end
 end

 # 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
 function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

     # Recompute forward state
-    X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false)
+    X1, X2, logS_T1, S = inverse(Y1, Y2, L; save=true, logdet=false)

     # Backpropagate residual
     ΔT = copy(ΔY2)
@@ -132,11 +132,11 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst
     end
     ΔX2 = ΔY2 .* S
     if set_grad
-        ΔX1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1) + ΔY1
+        ΔX1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1) + ΔY1
     else
-        ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad)
+        ΔX1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), X1; set_grad=set_grad)
         if L.logdet
-            _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad)
+            _, ∇logdet = L.RB.backward(tensor_cat(backward(coupling_logdet_backward(S), logS_T1, S, L.activation), 0 .*ΔT), X1; set_grad=set_grad)
         end
         ΔX1 += ΔY1
     end
@@ -152,7 +152,7 @@ end
 function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

     # Recompute inverse state
-    Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false)
+    Y1, Y2, logS_T1, S = forward(X1, X2, L; save=true, logdet=false)

     # Backpropagate residual
     ΔT = -ΔX2 ./ S
@@ -161,9 +161,9 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::
         set_grad ? (ΔS += coupling_logdet_backward(S)) : (∇logdet = -coupling_logdet_backward(S))
     end
     if set_grad
-        ΔY1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1) + ΔX1
+        ΔY1 = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1) + ΔX1
     else
-        ΔY1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1; set_grad=set_grad)
+        ΔY1, Δθ = L.RB.backward(tensor_cat(backward(ΔS, logS_T1, S, L.activation), ΔT), Y1; set_grad=set_grad)
         ΔY1 += ΔX1
     end
     ΔY2 = - ΔT
@@ -187,14 +187,14 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab
     logS_T1, logS_T2 = tensor_split(L.RB.forward(X1))
     ΔlogS_T1, ΔlogS_T2 = tensor_split(jacobian(ΔX1, Δθ, X1, L.RB)[1])
     S = L.activation.forward(logS_T1)
-    ΔS = L.activation.backward(ΔlogS_T1, S)
+    ΔS = backward(ΔlogS_T1, logS_T1, S, L.activation)
     Y2 = S.*X2 + logS_T2
     ΔY2 = ΔS.*X2 + S.*ΔX2 + ΔlogS_T2

     if logdet
         # Gauss-Newton approximation of logdet terms
         JΔθ = tensor_split(L.RB.jacobian(zeros(Float32, size(ΔX1)), Δθ, X1)[1])[1]
-        GNΔθ = -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, S), zeros(Float32, size(S))), X1)[2]
+        GNΔθ = -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS_T1, S, L.activation), zeros(Float32, size(S))), X1)[2]

         save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ)
     else

src/layers/invertible_layer_glow.jl

Lines changed: 7 additions & 7 deletions
@@ -129,14 +129,14 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerGlow; save=false) where
     X_ = tensor_cat(X1, X2)
     X = L.C.inverse(X_)

-    save == true ? (return X, X1, X2, Sm) : (return X)
+    save == true ? (return X, X1, X2, logSm, Sm) : (return X)
 end

 # Backward pass: Input (ΔY, Y), Output (ΔX, X)
 function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerGlow; set_grad::Bool=true) where {T,N}

     # Recompute forward state
-    X, X1, X2, S = inverse(Y, L; save=true)
+    X, X1, X2, logS, S = inverse(Y, L; save=true)

     # Backpropagate residual
     ΔY1, ΔY2 = tensor_split(ΔY)
@@ -148,10 +148,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL

     ΔX1 = ΔY1 .* S
     if set_grad
-        ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
+        ΔX2 = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT), X2) + ΔY2
     else
-        ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
-        _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
+        ΔX2, Δθrb = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), ΔT; ), X2; set_grad=set_grad)
+        _, ∇logdet = L.RB.backward(tensor_cat(backward(ΔS, logS, S, L.activation), 0f0.*ΔT;), X2; set_grad=set_grad)
         ΔX2 += ΔY2
     end
     ΔX_ = tensor_cat(ΔX1, ΔX2)
@@ -187,7 +187,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
     ΔlogS, ΔlogT = tensor_split(ΔlogS_T)
     logS, logT = tensor_split(logS_T)
     Sm = L.activation.forward(logS)
-    ΔS = L.activation.backward(ΔlogS, nothing;x=logS)
+    ΔS = backward(ΔlogS, logS, Sm, L.activation)
     Tm = logT
     ΔT = ΔlogT
     Y1 = Sm.*X1 + Tm
@@ -197,7 +197,7 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou

     # Gauss-Newton approximation of logdet terms
     JΔθ,_ = tensor_split(L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1])#[:, :, 1:k, :]
-    GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)
+    GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(backward(JΔθ, logS, Sm, L.activation), zeros(Float32, size(Sm))), X2)[2]; dims=1)

     L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
 end

src/layers/invertible_layer_irim.jl

Lines changed: 3 additions & 3 deletions
@@ -65,12 +65,12 @@ end
 @Flux.functor CouplingLayerIRIM

 # 2D Constructor from input dimensions
-function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
-    k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2)
+function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64;
+    k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2, rb_activation=ReLUlayer())

     # 1x1 Convolution and residual block for invertible layer
     C = Conv1x1(n_in)
-    RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+    RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims, activation=rb_activation)

     return CouplingLayerIRIM(C, RB)
 end

src/layers/layer_residual_block.jl

Lines changed: 11 additions & 11 deletions
@@ -127,10 +127,10 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {

     cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
-    # Return if only recomputing state
-    save && (return Y1, Y2, Y3)
-    # Finish forward
-    RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
+
+    X4 = RB.fan == true ? RB.activation.forward(Y3) : GaLU(Y3)
+    save && (return Y1, Y2, Y3, X2, X3, X4)
+    return X4
 end

 # Backward
@@ -140,25 +140,25 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},
     dims = collect(1:N-1); dims[end] +=1

     # Recompute forward states from input X
-    Y1, Y2, Y3 = forward(X1, RB; save=true)
+    Y1, Y2, Y3, X2, X3, X4 = forward(X1, RB; save=true)

     # Cdims
     cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
     cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])

     # Backpropagate residual ΔX4 and compute gradients
-    RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
+    RB.fan == true ? (ΔY3 = backward(ΔX4, Y3, X4, RB.activation)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
     ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
     ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)

-    ΔY2 = RB.activation.backward(ΔX3, Y2)
+    ΔY2 = backward(ΔX3, Y2, X3, RB.activation)
     ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
     ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
     Δb2 = sum(ΔY2, dims=dims)[inds...]

     cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

-    ΔY1 = RB.activation.backward(ΔX2, Y1)
+    ΔY1 = backward(ΔX2, Y1, X2, RB.activation)
     ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
     ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
     Δb1 = sum(ΔY1, dims=dims)[inds...]
@@ -187,21 +187,21 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},
     Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
     ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
     X2 = RB.activation.forward(Y1)
-    ΔX2 = RB.activation.backward(ΔY1, Y1)
+    ΔX2 = backward(ΔY1, Y1, X2, RB.activation)

     cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

     Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
     ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
     X3 = RB.activation.forward(Y2)
-    ΔX3 = RB.activation.backward(ΔY2, Y2)
+    ΔX3 = backward(ΔY2, Y2, X3, RB.activation)

     cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
     ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
     if RB.fan == true
         X4 = RB.activation.forward(Y3)
-        ΔX4 = RB.activation.backward(ΔY3, Y3)
+        ΔX4 = backward(ΔY3, Y3, X4, RB.activation)
     else
         ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
     end

src/utils/activation_functions.jl

Lines changed: 13 additions & 1 deletion
@@ -19,6 +19,18 @@ struct ActivationFunction
     backward::Function
 end

+function backward(Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}, activation::ActivationFunction) where {T, N}
+    backward_activation(activation.backward, activation.inverse, Δy, x, y)
+end
+
+function backward_activation(back::Function, inverse::Nothing, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+    back(Δy, x)
+end
+
+function backward_activation(back::Function, inverse::Function, Δy::AbstractArray{T, N}, x::AbstractArray{T, N}, y::AbstractArray{T, N}) where {T, N}
+    back(Δy, y)
+end
+
 function ReLUlayer()
     return ActivationFunction(ReLU, nothing, ReLUgrad)
 end
@@ -46,7 +58,7 @@ function GaLUlayer()
 end

 function ExpClampLayer()
-    return ActivationFunction(x -> ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
+    return ActivationFunction(x -> 2 * ExpClamp(x), y -> ExpClampInv(y/2f0), (Δy, y) -> ExpClampGrad(Δy*2f0, y/2f0))
 end
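With these methods in place, calling code no longer branches on whether an activation has an inverse: the new backward(Δy, x, y, activation) entry point dispatches on the inverse field and picks the right gradient rule itself. Below is a minimal usage sketch, not part of the commit, assuming code running where backward and the ReLUlayer/SigmoidLayer constructors used elsewhere in this diff are in scope; array sizes and variable names are only illustrative.

    x  = randn(Float32, 16, 16, 4, 2)   # pre-activation input
    Δy = randn(Float32, 16, 16, 4, 2)   # incoming gradient w.r.t. the activation output

    for act in (ReLUlayer(), SigmoidLayer())   # non-invertible vs. invertible activation
        y  = act.forward(x)                    # activation output
        Δx = backward(Δy, x, y, act)           # same call either way:
        # ReLUlayer has inverse === nothing, so the gradient is taken from the input x;
        # SigmoidLayer has an inverse, so the gradient is taken from the output y.
    end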

test/test_layers/test_coupling_layer_irim.jl

Lines changed: 5 additions & 5 deletions
@@ -20,9 +20,9 @@ X0 = randn(Float32, nx, ny, n_in, batchsize)
 dX = X - X0

 # Invertible layers
-L = CouplingLayerIRIM(n_in, n_hidden)
-L01 = CouplingLayerIRIM(n_in, n_hidden)
-L02 = CouplingLayerIRIM(n_in, n_hidden)
+L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+L01 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
+L02 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())

 ###################################################################################################
 # Test invertibility
@@ -131,9 +131,9 @@ end
 # Gradient test

 # Initialization
-L = CouplingLayerIRIM(n_in, n_hidden)
+L = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
 θ = deepcopy(get_params(L))
-L0 = CouplingLayerIRIM(n_in, n_hidden)
+L0 = CouplingLayerIRIM(n_in, n_hidden; rb_activation=SigmoidLayer())
 θ0 = deepcopy(get_params(L0))
 X = randn(Float32, nx, ny, n_in, batchsize)