Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pyhessian/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class hessian():
iii) the estimated eigenvalue density
"""

def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
def __init__(self, model, criterion, data=None, dataloader=None, weight_decay=None, cuda=True):
"""
        model: the model that needs Hessian information
criterion: the loss function
Expand Down Expand Up @@ -74,7 +74,8 @@ def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
loss.backward(create_graph=True)

# this step is used to extract the parameters from the model
params, gradsH = get_params_grad(self.model)
self.weight_decay = weight_decay
params, gradsH = get_params_grad(self.model, weight_decay=self.weight_decay)
self.params = params
self.gradsH = gradsH # gradient used for Hessian computation

Expand All @@ -91,7 +92,7 @@ def dataloader_hv_product(self, v):
outputs = self.model(inputs.to(device))
loss = self.criterion(outputs, targets.to(device))
loss.backward(create_graph=True)
params, gradsH = get_params_grad(self.model)
params, gradsH = get_params_grad(self.model, weight_decay=self.weight_decay)
self.model.zero_grad()
Hv = torch.autograd.grad(gradsH,
params,
Expand Down
23 changes: 14 additions & 9 deletions pyhessian/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#*
# *
# @file Different utility functions
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
Expand All @@ -16,12 +16,13 @@
#
# You should have received a copy of the GNU General Public License
# along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
#*
# *

import torch
import math
from torch.autograd import Variable
import numpy as np
from typing import Optional


def group_product(xs, ys):
Expand Down Expand Up @@ -58,7 +59,7 @@ def normalization(v):
return v


def get_params_grad(model):
def get_params_grad(model: torch.nn.Module, weight_decay: Optional[float] = None):
"""
get model parameters and corresponding gradients
"""
Expand All @@ -68,7 +69,13 @@ def get_params_grad(model):
if not param.requires_grad:
continue
params.append(param)
grads.append(0. if param.grad is None else param.grad + 0.)
if param.grad is not None:
grad = param.grad
if weight_decay is not None:
grad = grad.add(param, alpha=weight_decay)
grads.append(grad)
else:
grads.append(0.)
return params, grads


Expand All @@ -79,11 +86,9 @@ def hessian_vector_product(gradsH, params, v):
params is the corresponding variables,
v is the vector.
"""
hv = torch.autograd.grad(gradsH,
params,
grad_outputs=v,
only_inputs=True,
retain_graph=True)
hv = torch.autograd.grad(
gradsH, params, grad_outputs=v, only_inputs=True, retain_graph=True
)
return hv


Expand Down