
Using torch-mlir to compile LeNet5 training reports an error #1711


Description

@LiqinWeng

Source code (lenet_training.py):

import torch.nn as nn
from collections import OrderedDict
import torch.optim as optim

class C1(nn.Module):
    def __init__(self):
        super(C1, self).__init__()

        self.c1 = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
            ('relu1', nn.ReLU()),
            ('s1', nn.MaxPool2d(kernel_size=(2, 2), stride=2))
        ]))

    def forward(self, img):
        output = self.c1(img)
        return output

class C2(nn.Module):
    def __init__(self):
        super(C2, self).__init__()

        self.c2 = nn.Sequential(OrderedDict([
            ('c2', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('relu2', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2))
        ]))

    def forward(self, img):
        output = self.c2(img)
        return output

class C3(nn.Module):
    def __init__(self):
        super(C3, self).__init__()

        self.c3 = nn.Sequential(OrderedDict([
            ('c3', nn.Conv2d(16, 120, kernel_size=(5, 5))),
            ('relu3', nn.ReLU())
        ]))

    def forward(self, img):
        output = self.c3(img)
        return output

class F4(nn.Module):
    def __init__(self):
        super(F4, self).__init__()

        self.f4 = nn.Sequential(OrderedDict([
            ('f4', nn.Linear(120, 84)),
            ('relu4', nn.ReLU())
        ]))

    def forward(self, img):
        output = self.f4(img)
        return output

class F5(nn.Module):
    def __init__(self):
        super(F5, self).__init__()

        self.f5 = nn.Sequential(OrderedDict([
            ('f5', nn.Linear(84, 10)),
            ('sig5', nn.LogSoftmax(dim=-1))
        ]))

    def forward(self, img):
        output = self.f5(img)
        return output


class LeNet5(nn.Module):
    """
    Input - 1x32x32
    Output - 10
    """
    def __init__(self):
        super(LeNet5, self).__init__()

        self.c1 = C1()
        self.c2_1 = C2() 
        self.c2_2 = C2() 
        self.c3 = C3() 
        self.f4 = F4() 
        self.f5 = F5() 

    def forward(self, img):
        output = self.c1(img)

        x = self.c2_1(output)
        output = self.c2_2(output)

        output += x

        output = self.c3(output)
        output = output.view(img.size(0), -1)
        output = self.f4(output)
        output = self.f5(output)
        return output
    
criterion = nn.CrossEntropyLoss()
net = LeNet5()
net.train()
optimizer = optim.Adam(net.parameters(), lr=2e-3)
def learn(images, labels):
    optimizer.zero_grad()
    output = net(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    return loss
        
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch_mlir
from iree.compiler import compile_str
import iree.runtime
import functorch

data_test = MNIST('./data/mnist',
                  train=False,
                  download=True,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor()]))

data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

out_mhlo_mlir_path = "./lenet5_mhlo.mlir"

from functorch._src.compile_utils import strip_overloads
for i, (images, labels) in enumerate(data_test_loader):
    #print(learn(images, labels))
    train_args = (images, labels)
    graph = functorch.make_fx(learn)(*train_args)
    strip_overloads(graph)
    # NOTE: this requires patching torch_mlir/__init__.py so that backend_legal_ops
    # can be set even when the output type is not torch
    module = torch_mlir.compile(graph, train_args, output_type=torch_mlir.OutputType.MHLO, backend_legal_ops=["torch.aten.t", "torch.aten.addmm"])
    with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
        outf.write(str(module))

    print(f"MHLO IR of LeNet5 successfully written into {out_mhlo_mlir_path}")
    flatbuffer_blob = compile_str(str(module), target_backends=["llvm-cpu"], input_type="mhlo")
    compiled_model = iree.runtime.load_vm_flatbuffer(flatbuffer_blob, backend="llvm-cpu")
    print("compile done")
    print(compiled_model.forward(images, labels))

Running the command python lenet_training.py reports the following error:
[screenshot: error output]

So I patched torch_mlir's __init__.py so that backend_legal_ops can be set even when the output type is not torch:
[screenshot: modified __init__.py]
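For reference, my change is roughly the following (a paraphrase of the guard in torch_mlir/__init__.py as I have it locally; the exact upstream code may differ):

```python
# torch_mlir/__init__.py -- sketch of my local patch, not the upstream source.
# Upstream, compile() rejects backend_legal_ops unless output_type is
# OutputType.TORCH, roughly:
#
#   if backend_legal_ops is not None and output_type != OutputType.TORCH:
#       raise Exception("`backend_legal_ops` is only valid with the `torch` output type")
#
# I relaxed that check so the ops I pass (torch.aten.t, torch.aten.addmm)
# are kept as backend-legal on the MHLO path as well:
if backend_legal_ops is not None:
    backend_legal_ops = list(sorted(set(backend_legal_ops)))
```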
With that change in place, the error becomes:
[screenshot: new error output]

Does anyone know what is causing this problem?
