# Tensorflow: compute Hessian with respect to each sample

I have a tensor X of size M x D. We can interpret each row of X as a training sample and each column as a feature.

X is used to compute a tensor u of size M x 1 (in other words, u depends on X in the computational graph). We can interpret this as a vector of predictions; one for each sample. In particular, the m-th row of u is computed using only the m-th row of X.

Now, if I run tf.gradients(u, X)[0], I obtain an M x D tensor corresponding to the "per-sample" gradient of u with respect to X.

How can I similarly compute the "per-sample" Hessian tensor (i.e., an M x D x D quantity)?
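To pin down the target quantity, here it is written out in the question's notation (the symbol H for the result is introduced here only for illustration). tf.gradients differentiates the sum of its ys. Because u_m depends only on row m of X, differentiating that sum still yields per-sample derivatives, and the per-sample Hessian is the analogous second derivative:

$$
\frac{\partial}{\partial X_{m,d}} \sum_{k=1}^{M} u_k = \frac{\partial u_m}{\partial X_{m,d}},
\qquad
H_{m,i,j} = \frac{\partial^2 u_m}{\partial X_{m,i}\,\partial X_{m,j}},
\qquad H \in \mathbb{R}^{M \times D \times D}.
$$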

Addendum: Peter's answer below is correct. I also found a different approach using stacking and unstacking (using Peter's notation):

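```python
# grad and a as in Peter's example: grad = tf.gradients( c_sq, a )[ 0 ], an M x D tensor
hess2 = tf.stack([
    tf.gradients( tmp, a )[ 0 ]                   # row i of every sample's Hessian, as an M x D tensor
    for tmp in tf.unstack( grad, num=5, axis=1 )  # the D columns of the gradient, each of length M
], axis = 2)                                      # stacked result: M x D x D
```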


In Peter's example, D = 5 is the number of features. I suspect (but have not checked) that the above is faster for large M, as it skips the zero entries mentioned in Peter's answer.
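Why the unstack/stack loop recovers per-sample Hessians can be checked outside TensorFlow. The NumPy sketch below is only a stand-in. It replaces tf.gradients with central finite differences, and it uses the closed-form gradient of Peter's example. The helpers `grad_fn` and `numeric_grad` are hypothetical, written for this illustration. As with tf.gradients, each gradient column is summed over samples before differentiating. Because sample m depends only on row m, row m of each result is row i of sample m's Hessian.

```python
import numpy as np

# Peter's data: M = 3 samples, D = 5 features; b flattened to shape (5,).
a = np.array([[1.0] * 5, [2.0] * 5, [3.0] * 5])
b = np.arange(1.0, 6.0)
M, D = a.shape


def grad_fn(x):
    """Hypothetical helper: per-sample gradient of u_m = (x_m . b)^2, shape (M, D)."""
    return 2.0 * (x @ b)[:, None] * b[None, :]


def numeric_grad(f, x, eps=1e-4):
    """Hypothetical helper: central-difference gradient of scalar f w.r.t. every entry of x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        g[idx] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return g


# Counterpart of: for tmp in tf.unstack(grad, axis=1): tf.gradients(tmp, a)[0]
# Each column grad[:, i] is summed over samples, as tf.gradients does with ys.
columns = [numeric_grad(lambda x, i=i: grad_fn(x)[:, i].sum(), a) for i in range(D)]

# Counterpart of tf.stack(..., axis=2): hess2[m, j, i] = d^2 u_m / (d a[m, j] d a[m, i]).
hess2 = np.stack(columns, axis=2)

expected = np.broadcast_to(2.0 * np.outer(b, b), (M, D, D))
print(hess2.shape)                   # (3, 5, 5)
print(np.allclose(hess2, expected))  # True
```

The loop needs only D gradient calls, one per feature. The full Hessian, by contrast, has M x D x M x D entries. This fits the suspicion above, though no timing is measured here.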

## Comments

### tf.hessians() computes the full Hessian

tf.hessians() computes the Hessian of sum(ys) with respect to xs, regardless of the dimensions involved. Since both the gradient and xs have dimension M x D, the result has dimension M x D x M x D. But since the output for each exemplar depends only on its own row, most of the Hessian is zero: for sample m, only the slice hess[m, :, m, :] is nonzero, so only one index in the third dimension carries any value. Therefore, to get your desired result you could take the diagonal over the two M dimensions or, much more simply, sum out the third dimension like so:

```python
hess2 = tf.reduce_sum( hess, axis = 2 )
```
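Summing over axis 2 gives the same result as taking the diagonal over the two M dimensions. The NumPy sketch below checks this, using NumPy in place of TensorFlow so the zero structure can be inspected directly. The full M x D x M x D Hessian for Peter's example is written by hand from its closed form, 2 b_i b_j on the diagonal blocks. It is not computed by tf.hessians. The names `H_full`, `hess_sum` and `hess_diag` are illustrative.

```python
import numpy as np

# Peter's data: M = 3 samples, D = 5 features; b flattened to shape (5,).
a = np.array([[1.0] * 5, [2.0] * 5, [3.0] * 5])
b = np.arange(1.0, 6.0)
M, D = a.shape

# Full Hessian of sum_m u_m with u_m = (a_m . b)^2, written from the closed form:
# H_full[m, i, k, j] = 2 * b[i] * b[j] if m == k, and 0 otherwise,
# because sample m's prediction never touches row k != m.
H_full = np.zeros((M, D, M, D))
for m in range(M):
    H_full[m, :, m, :] = 2.0 * np.outer(b, b)

# Counterpart of tf.reduce_sum(hess, axis=2): add up the third (sample) axis.
hess_sum = H_full.sum(axis=2)

# Explicit diagonal over the two sample axes: hess_diag[m, i, j] = H_full[m, i, m, j].
hess_diag = np.einsum("mimj->mij", H_full)

print(H_full.shape, hess_sum.shape)           # (3, 5, 3, 5) (3, 5, 5)
print(np.array_equal(hess_sum, hess_diag))    # True
print(np.count_nonzero(H_full), H_full.size)  # 75 225
```

Only 1/M of the full Hessian is nonzero here. The other entries are exact zeros, so adding them leaves the diagonal blocks unchanged.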


Example code (tested):

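```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.gradients, tf.hessians)

a = tf.constant( [ [ 1.0, 1, 1, 1, 1 ], [ 2, 2, 2, 2, 2 ], [ 3, 3, 3, 3, 3 ] ] )  # M x D = 3 x 5
b = tf.constant( [ [ 1.0 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ] ] )                    # D x 1
c = tf.matmul( a, b )   # M x 1
c_sq = tf.square( c )   # M x 1, one prediction per sample

grad = tf.gradients( c_sq, a )[ 0 ]       # M x D per-sample gradient
hess = tf.hessians( c_sq, a )[ 0 ]        # M x D x M x D, mostly zeros
hess2 = tf.reduce_sum( hess, axis = 2 )   # M x D x D per-sample Hessian

with tf.Session() as sess:
    res = sess.run( [ c_sq, grad, hess2 ] )

for v in res:
    print( v.shape )
    print( v )
    print( "=======================")
```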


This will output:

```
(3, 1)
[[ 225.]
 [ 900.]
 [2025.]]
=======================
(3, 5)
[[ 30. 60. 90. 120. 150.]
 [ 60. 120. 180. 240. 300.]
 [ 90. 180. 270. 360. 450.]]
=======================
(3, 5, 5)
[[[ 2. 4. 6. 8. 10.]
  [ 4. 8. 12. 16. 20.]
  [ 6. 12. 18. 24. 30.]
  [ 8. 16. 24. 32. 40.]
  [10. 20. 30. 40. 50.]]

 [[ 2. 4. 6. 8. 10.]
  [ 4. 8. 12. 16. 20.]
  [ 6. 12. 18. 24. 30.]
  [ 8. 16. 24. 32. 40.]
  [10. 20. 30. 40. 50.]]

 [[ 2. 4. 6. 8. 10.]
  [ 4. 8. 12. 16. 20.]
  [ 6. 12. 18. 24. 30.]
  [ 8. 16. 24. 32. 40.]
  [10. 20. 30. 40. 50.]]]
=======================
```
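The printed numbers can be checked by hand. Here a_m denotes row m of a, and b is the 5 x 1 vector [1, 2, 3, 4, 5]^T. Each prediction is the square of a scalar, so:

$$
u_m = (a_m b)^2, \qquad
\nabla_{a_m} u_m = 2\,(a_m b)\,b^\top, \qquad
\nabla^2_{a_m} u_m = 2\,b\,b^\top .
$$

With a_1 b = 15, a_2 b = 30 and a_3 b = 45, this gives u = (225, 900, 2025) and gradient rows 30 b^T, 60 b^T and 90 b^T. Every sample gets the same 5 x 5 Hessian 2 b b^T, since the Hessian of a squared linear function does not depend on a.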