ELI5 How to use PyTorch

Published: Updated:

One morning I see this tweet in my timeline. It exclaims about the abilities of transformer models:

NN learns how to learn linear regression, decision trees, 2-layer ReLU nets 😲 furthermore: outperforms XGBoost, does Lasso in one-pass, seems not to rely on nearest-neighbor.

It refers to this work. I look carefully through the article. The example looks simple, and I want to play with linear approximation and find its limitations. Good thing they published the model and training scripts.

At work we recently deployed POS (point of sale) software written in Python. Web server, DB connector, abstract classes, function decorators. It is great. Python is great. But when I read the implementation of an ML algorithm from this paper, I start to hate Python.


To go through the evaluation process you need to run a Jupyter notebook.

It will be handy if CUDA 11.3 is already installed. I wrote about it earlier. Then, to install PyTorch, you will need to run:

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
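Before going further it is worth checking that the install actually worked. This is a minimal sketch, not something from the paper's repository: the variable names `cuda_ok`, `device` and `x` are mine, and it falls back to the CPU if PyTorch does not see a GPU.

```python
import torch

# Report the installed PyTorch version and whether it can see a CUDA device.
print("torch version:", torch.__version__)
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)

# Pick a device; fall back to the CPU when no GPU is visible.
device = torch.device("cuda" if cuda_ok else "cpu")

# Allocate a small tensor on that device as a smoke test.
x = torch.ones(2, 2, device=device)
print(x.device)
```

If `CUDA available` prints `False` on a machine with a GPU, the usual suspects are a CPU-only wheel or a driver that does not match the CUDA version.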

Also the following modules will be required:

  • transformers
  • sklearn (installed with pip as scikit-learn)
  • numpy - vectors and matrices (the @ operator is shorthand for matmul)
  • xgboost
  • munch - config reading
  • tqdm - fancy progress bar for the terminal
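Since the `@` operator shows up a lot in this kind of code, here is a tiny sketch of what it does. The matrix `A` and vector `x` are made-up values for illustration; they do not come from the paper's scripts.

```python
import numpy as np

# A 2x2 matrix and a length-2 vector (illustrative values only).
A = np.array([[1, 2],
              [3, 4]])
x = np.array([1, 1])

# The @ operator is shorthand for np.matmul.
y_operator = A @ x
y_function = np.matmul(A, x)

print(y_operator)  # [3 7]
print(np.array_equal(y_operator, y_function))  # True
```

Each output element is the dot product of one row of `A` with `x`, which is exactly what a linear layer does before any bias or activation is applied.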

Understand Models in PyTorch

The core of any neural network model in PyTorch (and apparently a Transformer as well) is nn.Module.

What can you say about its trivial example?

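import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        # nn.Module must be initialized before submodules are assigned,
        # otherwise the assignments below raise an AttributeError.
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size)
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # Apply each convolution followed by a ReLU activation.
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))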

Clearly it has something to do with 2D convolutions: the first layer takes 1 input channel and produces 20, and the second takes 20 channels and produces 20. Channels? Easy peasy. The kernel size is 5 for both. The signal after each layer passes through a rectified linear unit (ReLU), which zeroes out negative values.
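To see what those numbers mean in practice, it helps to push a dummy tensor through the model and look at the shapes. This is a sketch under my own assumptions: the 28x28 input size and the names `x`, `out` and `n_params` are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()  # must run before submodules are assigned
        self.conv1 = nn.Conv2d(1, 20, 5)   # 1 channel in, 20 out, 5x5 kernel
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 channels in, 20 out, 5x5 kernel

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()

# Hypothetical input: a batch of 1 single-channel 28x28 "image",
# laid out as (batch, channels, height, width).
x = torch.randn(1, 1, 28, 28)

out = model(x)  # calling the module runs forward()
print(out.shape)  # torch.Size([1, 20, 20, 20])

# Count the learnable weights and biases in both layers.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 10540
```

Without padding, each 5x5 convolution trims 4 pixels from the height and the width, so 28 becomes 24 and then 20, while the channel count goes from 1 to 20 and stays there.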

Still nothing makes sense? Okay, let me explain this one more time. Given two-dimensional data (for a stock price chart, that is price versus closing time), we want to catch patterns and correlations and apply them in the future. This data is the input to the first layer. Then
