270270# a * b
271271
272272function reverse_mul! (output, output_deriv, a, b, a_tmp, b_tmp)
273- istracked (a) && increment_deriv! (a, mul! (a_tmp, output_deriv, transpose (value (b))))
273+ if istracked (a)
274+ if a_tmp isa AbstractVector && b isa AbstractMatrix
275+ # this branch is required for scalar-valued functions that
276+ # involve outer-products of vectors, for such functions, the target
277+ # a_temp is a vector, but when b is a matrix, we cannot multiply into a vector,
278+ # so need to reshape memory to look like matrix (see PositiveFactorizations.jl)
279+ increment_deriv! (a, mul! (reshape (a_tmp, :, 1 ), output_deriv, transpose (value (b))))
280+ else
281+ increment_deriv! (a, mul! (a_tmp, output_deriv, transpose (value (b))))
282+ end
283+ end
274284 istracked (b) && increment_deriv! (b, mul! (b_tmp, transpose (value (a)), output_deriv))
275285end
276286
@@ -279,8 +289,14 @@ for (f, F) in ((:transpose, :Transpose), (:adjoint, :Adjoint))
279289 # a * f(b)
280290 function reverse_mul! (output, output_deriv, a, b:: $F , a_tmp, b_tmp)
281291 _b = ($ f)(b)
282- istracked (a) && increment_deriv! (a, mul! (a_tmp, output_deriv, mulargvalue (b)))
283- istracked (_b) && increment_deriv! (_b, ($ f)(mul! (b_tmp, ($ f)(output_deriv), value (a))))
292+ if istracked (a)
293+ if a_tmp isa AbstractVector
294+ increment_deriv! (a, mul! (reshape (a_tmp, :, 1 ), output_deriv, mulargvalue (_b)))
295+ else
296+ increment_deriv! (a, mul! (a_tmp, output_deriv, mulargvalue (b)))
297+ end
298+ end
299+ istracked (_b) && increment_deriv! (_b, ($ f)(mul! (($ f)(b_tmp), ($ f)(output_deriv), value (a))))
284300 end
285301 # f(a) * b
286302 function reverse_mul! (output, output_deriv, a:: $F , b, a_tmp, b_tmp)
0 commit comments