
example_introduction

! pip install -q torchview
! pip install -q -U graphviz
from torchview import draw_graph
from torch import nn
import torch
import graphviz

# when running on VSCode, run the command below;
# the svg format on VSCode does not give the desired result
graphviz.set_jupyter_format('png')
{"type":"string"}

The purpose of this notebook is to introduce the API and notation of the torchview package with common use cases.

We start with a simple MLP model.

class MLP(nn.Module):
    """Multi Layer Perceptron with inplace option.
    Make sure inplace=true and false has the same visual graph"""

    def __init__(self, inplace: bool = True) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(inplace),
            nn.Linear(128, 128),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layers(x)
        return x
model_graph_1 = draw_graph(
    MLP(), input_size=(2, 128),
    graph_name='MLP',
    hide_inner_tensors=False,
    hide_module_functions=False,
)
model_graph_1.visual_graph

Any visual graph representation of a PyTorch model produced by the torchview package consists of nodes and directed edges (undirected ones may be added in future releases). Nodes are connected by edges that indicate the flow of information through the neural network.

There are 3 types of nodes:

  • Tensor Node
  • Function Node
  • Module Node

1) Tensor Node: This node is shown in bright yellow. Its label is of the form {tensor-name}{depth}: {tensor-shape}. tensor-name can take 3 values: input-tensor, hidden-tensor, or output-tensor. depth is the depth of the tensor in the hierarchy of modules.

2) Function Node: This node is shown in bright blue. Its label is of the form {Function-name}{depth}: {input and output shapes}.

3) Module Node: This node is shown in bright green. Its label is of the form {Module-name}{depth}: {input and output shapes}.

In the MLP example above, the input tensor is passed to the main module, MLP. The MLP module passes it to its submodule Sequential, which in turn passes it to its submodule Linear. Inside the Linear module, the linear function F.linear is finally applied to the input tensor and returns an output tensor. That tensor is then sent to the ReLU layer, and so on.
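This level of detail comes from setting hide_inner_tensors and hide_module_functions to False in the call above. As a minimal sketch, flipping both flags back to their defaults (True) should give a coarser, module-level view of the same MLP, with hidden tensors and module-internal functions such as F.linear left out of the drawing:

model_graph_modules_only = draw_graph(
    MLP(), input_size=(2, 128),
    graph_name='MLP',
    # defaults: hide intermediate tensor nodes and functions that live
    # inside modules, so mostly Module nodes (plus the input and output
    # tensors) remain in the graph
    hide_inner_tensors=True,
    hide_module_functions=True,
)
model_graph_modules_only.visual_graph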

Now, we show how the rolling mechanism for recursive modules works. To demonstrate this, we use an RNN module.

class SimpleRNN(nn.Module):
    """Simple RNN"""

    def __init__(self, inplace: bool = True) -> None:
        super().__init__()
        self.hid_dim = 2
        self.input_dim = 3
        self.max_length = 4
        self.lstm = nn.LSTMCell(self.input_dim, self.hid_dim)
        self.activation = nn.LeakyReLU(inplace=inplace)
        self.projection = nn.Linear(self.hid_dim, self.input_dim)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        b_size = token_embedding.size()[0]
        hx = torch.randn(b_size, self.hid_dim, device=token_embedding.device)
        cx = torch.randn(b_size, self.hid_dim, device=token_embedding.device)

        for _ in range(self.max_length):
            hx, cx = self.lstm(token_embedding, (hx, cx))
            hx = self.activation(hx)

        return hx
model_graph_2 = draw_graph(
    SimpleRNN(), input_size=(2, 3),
    graph_name='RecursiveNet',
    roll=True
)
model_graph_2.visual_graph

In the graph above, we see a rolled representation of the RNN with LSTM units: the repeated LSTMCell and LeakyReLU calls are each collapsed into a single node. The repetition is represented by the numbers shown on the edges; each number indicates how many times that edge is traversed during forward propagation. For instance, the first number, 4, is the number of times token_embedding is used.

If an edge is used only once, the number is not shown.
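For comparison, drawing the same model without rolling should unroll the loop so that each of the four LSTMCell calls (and each LeakyReLU call) appears as its own node. A quick sketch, simply reusing the call above with roll set back to its default of False:

model_graph_2_unrolled = draw_graph(
    SimpleRNN(), input_size=(2, 3),
    graph_name='RecursiveNet',
    # without rolling, each loop iteration gets its own nodes
    # instead of a single node annotated with edge counts
    roll=False,
)
model_graph_2_unrolled.visual_graph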

Another useful feature is resizing. Say the previous image of the RNN is too big for your purposes; we can use the resize feature to rescale it by 0.5.

model_graph_2.resize_graph(scale=0.5)
model_graph_2.visual_graph

It got smaller!!!
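Since visual_graph is a plain graphviz object, the resized graph can also be written to disk with graphviz's own render method. A small sketch; the output name 'rnn_graph' here is just made up for illustration:

# export the resized graph via graphviz; format can be 'png', 'svg',
# 'pdf', etc., and cleanup=True removes the intermediate dot source file
model_graph_2.visual_graph.render('rnn_graph', format='png', cleanup=True)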