torch_function

This document explains some important points about __torch_function__ and its role in this project.

  • The __torch_function__ API was first introduced in v1.5.0, here

  • Subclass preservation for some important operators was introduced here. This property is essential since the torchview package tracks tensors only through the RecorderTensor subclass (a short demonstration is included further below).

  • Some important fixes were introduced here. For instance, support for F.embedding was included. Otherwise, F.embedding under the __torch_function__ of a subclass would return NotImplemented, leading to a plain torch.Tensor, which is not desired. To prevent this issue (and to support torch versions < 1.10), we added the code below in recorder_tensor.py:
    # This is necessary for torch version < 1.10
    if func in [F.linear, F.embedding]:
        out = nn.parameter.Parameter.__torch_function__(
            func, types, args, kwargs).as_subclass(RecorderTensor)
    else:
        # use original torch_function; otherwise,
        # it leads to infinite recursive call of torch_function
        out = super().__torch_function__(func, types, args, kwargs)

To be precise about the versions:

  • F.linear returns `NotImplemented` for versions 1.7.1, 1.8, 1.9
  • F.embedding returns `NotImplemented` for versions 1.7.1, 1.8, 1.9

This package does not support torch versions <= 1.6, since torch.Tensor does not have __torch_function__ as a class method in those versions.
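
For illustration, here is a minimal, self-contained sketch of the workaround above: a torch.Tensor subclass whose __torch_function__ routes F.linear and F.embedding through nn.parameter.Parameter.__torch_function__ and casts the result back to the subclass. TracedTensor is a hypothetical stand-in for RecorderTensor; this is not the actual recorder_tensor.py implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TracedTensor(torch.Tensor):
        # Hypothetical stand-in for RecorderTensor, for illustration only.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            if func in [F.linear, F.embedding]:
                # Mirror the workaround above: run the op through Parameter's
                # __torch_function__ and cast the result back to the subclass.
                return nn.parameter.Parameter.__torch_function__(
                    func, types, args, kwargs).as_subclass(cls)
            # Default dispatch; preserves the subclass on recent torch versions.
            return super().__torch_function__(func, types, args, kwargs)

    weight = torch.randn(10, 4).as_subclass(TracedTensor)
    indices = torch.tensor([1, 3, 5])
    out = F.embedding(indices, weight)
    print(type(out).__name__)  # TracedTensor

Per the comment in the snippet from recorder_tensor.py, this special-casing is only needed for torch versions below 1.10.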

  • Some other relevant PRs: PR1, PR2
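
To see the subclass-preservation property from the second bullet directly, a bare torch.Tensor subclass is enough; PlainSubclass below is just an illustrative name. Exactly which operators preserve the subclass depends on the torch version, which is why the PRs above matter for torchview.

    import torch

    class PlainSubclass(torch.Tensor):
        # No custom __torch_function__; relies on torch.Tensor's default one.
        pass

    x = torch.randn(2, 3).as_subclass(PlainSubclass)
    y = torch.relu(x + 1)  # both ops dispatch through __torch_function__
    print(isinstance(y, PlainSubclass))  # True on versions that preserve the subclass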

Meta tensor related links