import torch
def inverse_hessian_product(loss_func, x: torch.Tensor, damping: float = 0.01, iteration: int = 5) -> torch.Tensor:
    """Iteratively refine an inverse-Hessian-times-gradient style estimate.

    The estimate is initialised to the gradient of ``loss_func`` at ``x`` and
    then updated ``iteration`` times with::

        estimate = grad + (1 - damping) * estimate
                   - (H @ estimate) / ||H @ estimate|| * ||grad||

    where ``H @ estimate`` is the Hessian-vector product of ``loss_func`` at
    ``x``. Note that the Hessian-vector product is rescaled to the gradient's
    norm on every step, so this is not the plain
    ``grad + (I - H) @ estimate`` recursion.

    Args:
        loss_func: Callable taking a tensor shaped like ``x`` and returning a
            single-element loss tensor (required by
            ``torch.autograd.functional.hvp``).
        x: Point at which the gradient and Hessian-vector products are
            evaluated. ``requires_grad`` is switched on **in place**, so the
            caller's tensor is modified.
        damping: The previous estimate is scaled by ``1 - damping`` on each
            step.
        iteration: Number of update steps. With ``0`` the function returns a
            clone of the gradient.

    Returns:
        The final estimate, with the same shape as ``x``.
    """
    x.requires_grad_(True)

    # Gradient of the loss with respect to x. create_graph=True is kept from
    # the original implementation: it leaves the gradient (and therefore the
    # returned estimate) attached to the autograd graph when it depends on x.
    loss = loss_func(x)
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    grad_norm = torch.norm(grad)
    cur_estimate = grad.clone()

    for _ in range(iteration):
        # Hessian-vector product H @ cur_estimate.
        _, hvp = torch.autograd.functional.hvp(loss_func, x, cur_estimate)
        hvp_norm = torch.norm(hvp, p='fro')
        # A zero Hessian-vector product (e.g. a linear loss) has zero norm;
        # dividing by it would yield 0/0 = NaN and poison every later step.
        # Clamping to the smallest positive normal value leaves non-zero norms
        # untouched and makes the rescaled term exactly zero in that case.
        safe_norm = hvp_norm.clamp_min(torch.finfo(hvp_norm.dtype).tiny)
        cur_estimate = grad + (1 - damping) * cur_estimate - (hvp / safe_norm) * grad_norm
    return cur_estimate