Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions torchstat/compute_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def compute_memory(module, inp, out):
return compute_Pool2d_memory(module, inp, out)
else:
print(f"[Memory]: {type(module).__name__} is not supported!")
return (0, 0)
return 0, 0
pass


Expand All @@ -28,20 +28,21 @@ def num_params(module):

def compute_ReLU_memory(module, inp, out):
assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
batch_size = inp.size()[0]
mread = batch_size * inp.size()[1:].numel()
mwrite = batch_size * inp.size()[1:].numel()

return (mread, mwrite)
mread = inp.numel()
mwrite = out.numel()

return mread, mwrite


def compute_PReLU_memory(module, inp, out):
assert isinstance(module, (nn.PReLU))
assert isinstance(module, nn.PReLU)

batch_size = inp.size()[0]
mread = batch_size * (inp.size()[1:].numel() + num_params(module))
mwrite = batch_size * inp.size()[1:].numel()
mread = batch_size * (inp[0].numel() + num_params(module))
mwrite = out.numel()

return (mread, mwrite)
return mread, mwrite


def compute_Conv2d_memory(module, inp, out):
Expand All @@ -50,39 +51,42 @@ def compute_Conv2d_memory(module, inp, out):
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

batch_size = inp.size()[0]
in_c = inp.size()[1]
out_c, out_h, out_w = out.size()[1:]

# This includes weighs with bias if the module contains it.
mread = batch_size * (inp.size()[1:].numel() + num_params(module))
mwrite = batch_size * out_c * out_h * out_w
return (mread, mwrite)
# This includes weights with bias if the module contains it.
mread = batch_size * (inp[0].numel() + num_params(module))
mwrite = out.numel()
return mread, mwrite


def compute_BatchNorm2d_memory(module, inp, out):
assert isinstance(module, nn.BatchNorm2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

batch_size, in_c, in_h, in_w = inp.size()
mread = batch_size * (inp[0].numel() + 2 * in_c)
mwrite = out.numel()

mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
mwrite = inp.size().numel()
return (mread, mwrite)
return mread, mwrite


def compute_Linear_memory(module, inp, out):
assert isinstance(module, nn.Linear)
assert len(inp.size()) == 2 and len(out.size()) == 2

batch_size = inp.size()[0]
mread = batch_size * (inp.size()[1:].numel() + num_params(module))
mwrite = out.size().numel()

return (mread, mwrite)
# This includes weights with bias if the module contains it.
mread = batch_size * (inp[0].numel() + num_params(module))
mwrite = out.numel()

return mread, mwrite


def compute_Pool2d_memory(module, inp, out):
assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
batch_size = inp.size()[0]
mread = batch_size * inp.size()[1:].numel()
mwrite = batch_size * out.size()[1:].numel()
return (mread, mwrite)

mread = inp.numel()
mwrite = out.numel()

return mread, mwrite