From dbfb016793399c3bed301a48f3d1e7fb62461b21 Mon Sep 17 00:00:00 2001
From: Diego Canez
Date: Tue, 16 Apr 2024 15:44:56 +0200
Subject: [PATCH] feat: add support for weight_decay

---
 pyhessian/hessian.py |  7 ++++---
 pyhessian/utils.py   | 23 ++++++++++++++---------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/pyhessian/hessian.py b/pyhessian/hessian.py
index 9a9ec1b..5f6a467 100644
--- a/pyhessian/hessian.py
+++ b/pyhessian/hessian.py
@@ -34,7 +34,7 @@ class hessian():
     iii) the estimated eigenvalue density
     """
 
-    def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
+    def __init__(self, model, criterion, data=None, dataloader=None, weight_decay=None, cuda=True):
         """
         model: the model that needs Hessain information
         criterion: the loss function
@@ -74,7 +74,8 @@ def __init__(self, model, criterion, data=None, dataloader=None, cuda=True):
             loss.backward(create_graph=True)
 
         # this step is used to extract the parameters from the model
-        params, gradsH = get_params_grad(self.model)
+        self.weight_decay = weight_decay
+        params, gradsH = get_params_grad(self.model, weight_decay=self.weight_decay)
         self.params = params
         self.gradsH = gradsH  # gradient used for Hessian computation
 
@@ -91,7 +92,7 @@ def dataloader_hv_product(self, v):
             outputs = self.model(inputs.to(device))
             loss = self.criterion(outputs, targets.to(device))
             loss.backward(create_graph=True)
-            params, gradsH = get_params_grad(self.model)
+            params, gradsH = get_params_grad(self.model, weight_decay=self.weight_decay)
             self.model.zero_grad()
             Hv = torch.autograd.grad(gradsH,
                                      params,
diff --git a/pyhessian/utils.py b/pyhessian/utils.py
index 4ab7942..f24fca7 100644
--- a/pyhessian/utils.py
+++ b/pyhessian/utils.py
@@ -1,4 +1,4 @@
-#*
+# *
 # @file Different utility functions
 # Copyright (c) Zhewei Yao, Amir Gholami
 # All rights reserved.
@@ -16,12 +16,13 @@
 #
 # You should have received a copy of the GNU General Public License
 # along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
-#*
+# *
 
 import torch
 import math
 from torch.autograd import Variable
 import numpy as np
+from typing import Optional
 
 
 def group_product(xs, ys):
@@ -58,7 +59,7 @@ def normalization(v):
     return v
 
 
-def get_params_grad(model):
+def get_params_grad(model: torch.nn.Module, weight_decay: Optional[float] = None):
     """
     get model parameters and corresponding gradients
     """
@@ -68,7 +69,13 @@
         if not param.requires_grad:
             continue
         params.append(param)
-        grads.append(0. if param.grad is None else param.grad + 0.)
+        if param.grad is not None:
+            grad = param.grad
+            if weight_decay is not None:
+                grad = grad.add(param, alpha=weight_decay)
+            grads.append(grad)
+        else:
+            grads.append(0.)
     return params, grads
 
 
@@ -79,11 +86,9 @@
 def hessian_vector_product(gradsH, params, v):
     """
     compute the hessian vector product of Hv, where gradsH is the gradient at the
     current point, params is the corresponding variables, v is the vector.
     """
-    hv = torch.autograd.grad(gradsH,
-                             params,
-                             grad_outputs=v,
-                             only_inputs=True,
-                             retain_graph=True)
+    hv = torch.autograd.grad(
+        gradsH, params, grad_outputs=v, only_inputs=True, retain_graph=True
+    )
     return hv