Skip to content

Updated code for Inverse Hessian Product  #1

@lindsey98

Description

@lindsey98
import torch

def inverse_hessian_product(loss_func, x, damping=0.01, iteration=5):
    """Approximate the inverse-Hessian-vector product H^{-1} g, where g is the
    gradient of ``loss_func`` at ``x``, via a LiSSA-style fixed-point iteration.

    The recurrence (with v_0 = g) is
        v_t = g + (1 - damping) * v_{t-1} - (H v_{t-1} / ||H v_{t-1}||) * ||g||
    i.e. each Hessian-vector product is rescaled to the gradient's norm
    (this normalization is the original author's formulation).

    Args:
        loss_func: Callable mapping a tensor to a scalar loss; must be twice
            differentiable with respect to its input.
        x: Tensor at which the gradient and Hessian are evaluated.
            NOTE: ``requires_grad`` is enabled on ``x`` in place (side effect
            preserved from the original implementation).
        damping: Damping coefficient in [0, 1) stabilizing the iteration.
        iteration: Number of fixed-point iterations to run.

    Returns:
        A detached tensor with the same shape as ``x`` holding the estimate.
    """
    x.requires_grad_(True)
    loss = loss_func(x)
    # create_graph is NOT needed here: the Hessian-vector products below are
    # recomputed from scratch by torch.autograd.functional.hvp, so keeping a
    # second-order graph alive on `grad` (and, through `cur_estimate`, across
    # every iteration) only grows memory without affecting the result.
    grad = torch.autograd.grad(loss, x)[0].detach()
    grad_norm = torch.norm(grad)
    cur_estimate = grad.clone()

    # Update: gradient + (1 - damping) * cur_estimate - normalized H*v term,
    # where cur_estimate is initialized to the gradient.
    for _ in range(iteration):
        # Hessian-vector product H @ cur_estimate evaluated at x.
        _, hvp = torch.autograd.functional.hvp(loss_func, x, cur_estimate)
        with torch.no_grad():
            hvp_norm = torch.norm(hvp, p='fro')
            if hvp_norm > 0:
                correction = (hvp / hvp_norm) * grad_norm
            else:
                # H @ v == 0: the original divided by zero here (yielding NaN);
                # a zero Hessian-vector product contributes nothing instead.
                correction = torch.zeros_like(hvp)
            cur_estimate = grad + (1 - damping) * cur_estimate - correction

    return cur_estimate

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions