@@ -96,9 +96,9 @@ function forward(X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLa
9696 Y2 = S.* X2 + logS_T2
9797
9898 if logdet
99- save ? (return X1, Y2, coupling_logdet_forward (S), S) : (return X1, Y2, coupling_logdet_forward (S))
99+ save ? (return X1, Y2, coupling_logdet_forward (S), logS_T1, S) : (return X1, Y2, coupling_logdet_forward (S))
100100 else
101- save ? (return X1, Y2, S) : (return X1, Y2)
101+ save ? (return X1, Y2, logS_T1, S) : (return X1, Y2)
102102 end
103103end
104104
@@ -112,17 +112,17 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa
112112 X2 = (Y2 - logS_T2) ./ (S .+ eps (T)) # add epsilon to avoid division by 0
113113
114114 if logdet
115- save == true ? (return Y1, X2, - coupling_logdet_forward (S), S) : (return Y1, X2, - coupling_logdet_forward (S))
115+ save == true ? (return Y1, X2, - coupling_logdet_forward (S), logS_T1, S) : (return Y1, X2, - coupling_logdet_forward (S))
116116 else
117- save == true ? (return Y1, X2, S) : (return Y1, X2)
117+ save == true ? (return Y1, X2, logS_T1, S) : (return Y1, X2)
118118 end
119119end
120120
121121# 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
122122function backward (ΔY1:: AbstractArray{T, N} , ΔY2:: AbstractArray{T, N} , Y1:: AbstractArray{T, N} , Y2:: AbstractArray{T, N} , L:: CouplingLayerBasic ; set_grad:: Bool = true ) where {T, N}
123123
124124 # Recompute forward state
125- X1, X2, S = inverse (Y1, Y2, L; save= true , logdet= false )
125+ X1, X2, logS_T1, S = inverse (Y1, Y2, L; save= true , logdet= false )
126126
127127 # Backpropagate residual
128128 ΔT = copy (ΔY2)
@@ -132,11 +132,11 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst
132132 end
133133 ΔX2 = ΔY2 .* S
134134 if set_grad
135- ΔX1 = L. RB. backward (tensor_cat (L . activation . backward (ΔS, S ), ΔT), X1) + ΔY1
135+ ΔX1 = L. RB. backward (tensor_cat (backward (ΔS, logS_T1, S, L . activation ), ΔT), X1) + ΔY1
136136 else
137- ΔX1, Δθ = L. RB. backward (tensor_cat (L . activation . backward (ΔS, S ), ΔT), X1; set_grad= set_grad)
137+ ΔX1, Δθ = L. RB. backward (tensor_cat (backward (ΔS, logS_T1, S, L . activation ), ΔT), X1; set_grad= set_grad)
138138 if L. logdet
139- _, ∇logdet = L. RB. backward (tensor_cat (L . activation . backward (coupling_logdet_backward (S), S ), 0 .* ΔT), X1; set_grad= set_grad)
139+ _, ∇logdet = L. RB. backward (tensor_cat (backward (coupling_logdet_backward (S), logS_T1, S, L . activation ), 0 .* ΔT), X1; set_grad= set_grad)
140140 end
141141 ΔX1 += ΔY1
142142 end
152152function backward_inv (ΔX1:: AbstractArray{T, N} , ΔX2:: AbstractArray{T, N} , X1:: AbstractArray{T, N} , X2:: AbstractArray{T, N} , L:: CouplingLayerBasic ; set_grad:: Bool = true ) where {T, N}
153153
154154 # Recompute inverse state
155- Y1, Y2, S = forward (X1, X2, L; save= true , logdet= false )
155+ Y1, Y2, logS_T1, S = forward (X1, X2, L; save= true , logdet= false )
156156
157157 # Backpropagate residual
158158 ΔT = - ΔX2 ./ S
@@ -161,9 +161,9 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::
161161 set_grad ? (ΔS += coupling_logdet_backward (S)) : (∇logdet = - coupling_logdet_backward (S))
162162 end
163163 if set_grad
164- ΔY1 = L. RB. backward (tensor_cat (L . activation . backward (ΔS, S ), ΔT), Y1) + ΔX1
164+ ΔY1 = L. RB. backward (tensor_cat (backward (ΔS, logS_T1, S, L . activation ), ΔT), Y1) + ΔX1
165165 else
166- ΔY1, Δθ = L. RB. backward (tensor_cat (L . activation . backward (ΔS, S ), ΔT), Y1; set_grad= set_grad)
166+ ΔY1, Δθ = L. RB. backward (tensor_cat (backward (ΔS, logS_T1, S, L . activation ), ΔT), Y1; set_grad= set_grad)
167167 ΔY1 += ΔX1
168168 end
169169 ΔY2 = - ΔT
@@ -187,14 +187,14 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab
187187 logS_T1, logS_T2 = tensor_split (L. RB. forward (X1))
188188 ΔlogS_T1, ΔlogS_T2 = tensor_split (jacobian (ΔX1, Δθ, X1, L. RB)[1 ])
189189 S = L. activation. forward (logS_T1)
190- ΔS = L . activation . backward (ΔlogS_T1, S )
190+ ΔS = backward (ΔlogS_T1, logS_T1, S, L . activation )
191191 Y2 = S.* X2 + logS_T2
192192 ΔY2 = ΔS.* X2 + S.* ΔX2 + ΔlogS_T2
193193
194194 if logdet
195195 # Gauss-Newton approximation of logdet terms
196196 JΔθ = tensor_split (L. RB. jacobian (zeros (Float32, size (ΔX1)), Δθ, X1)[1 ])[1 ]
197- GNΔθ = - L. RB. adjointJacobian (tensor_cat (L . activation . backward (JΔθ, S ), zeros (Float32, size (S))), X1)[2 ]
197+ GNΔθ = - L. RB. adjointJacobian (tensor_cat (backward (JΔθ, logS_T1, S, L . activation ), zeros (Float32, size (S))), X1)[2 ]
198198
199199 save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward (S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward (S), GNΔθ)
200200 else
0 commit comments