Is there any problem in my code? #18673
nobuo-toyama
started this conversation in
General
Replies: 0 comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
import random
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning import Trainer
def add_to_class(Class):
"""作成されたクラスにメソッドを追加する。"""
def wrapper(obj):
setattr(Class, obj.name, obj)
return wrapper
class A:
def init(self):
self.b = 1
self.do()
a = A()
def save_hyperparameters(self, ignore=[]):
raise NotImplementedError
class HyperParameters:
def init(self):
self.hyperparameters = {}
class B(HyperParameters):
def init(self, a, b):
super().init()
self.a = a
self.b = b
b = B(a=1, b=2)
print(b.hyperparameters)
class SyntheticRegressionData(pl.LightningDataModule):
def init(self, w, b, noise=0.01, num_train=1000, num_val=1000, batch_size=32):
super().init()
self.save_hyperparameters()
self.num_train = num_train
self.num_val = num_val
self.batch_size = batch_size
n = num_train + num_val
self.X = torch.randn(n, len(w))
noise = torch.randn(n, 1) * noise
self.y = torch.matmul(self.X, w.reshape((-1, 1))) + b + noise
data = SyntheticRegressionData(w=torch.tensor([2, 3.5]), b=3)
add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
if train:
indices = list(range(0, self.num_train))
# The examples are read in random order
random.shuffle(indices)
else:
indices = list(range(self.num_train, self.num_train+self.num_val))
for i in range(0, len(indices), self.batch_size):
batch_indices = torch.tensor(indices[i: i+self.batch_size])
yield self.X[batch_indices], self.y[batch_indices]
データモジュールのインスタンス化
data = SyntheticRegressionData(w=torch.tensor([2, 3.5]), b=3)
データローダーからデータを取得して表示
X, y = next(iter(data.get_dataloader(train=True)))
print('X shape:', X.shape, '\ny shape:', y.shape)
add_to_class(LightningDataModule)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
tensors = tuple(a[indices] for a in tensors)
dataset = torch.utils.data.TensorDataset(*tensors)
return torch.utils.data.DataLoader(dataset, self.batch_size,
shuffle=train)
add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
i = slice(0, self.num_train) if train else slice(self.num_train, None)
return self.get_tensorloader((self.X, self.y), train, i)
X, y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)
class LinearRegressionScratch(LightningDataModule):
"""ゼロから実装する線形回帰モデル"""
def init(self, num_inputs, lr, sigma=0.01):
super().init()
self.save_hyperparameters()
self.w = torch.normal(0, sigma, (num_inputs, 1), requires_grad=True)
self.b = torch.zeros(1, requires_grad=True)
add_to_class(LinearRegressionScratch)
def forward(self, X):
return torch.matmul(X, self.w) + self.b
add_to_class(LinearRegressionScratch)
def loss(self, y_hat, y):
l = (y_hat - y) ** 2 / 2
return l.mean()
class SGD(HyperParameters):
"""ミニバッチ勾配降下法"""
def init(self, params, lr):
self.save_hyperparameters()
add_to_class(LinearRegressionScratch)
def configure_optimizers(self):
return SGD([self.w, self.b], self.lr)
Trainer = pl.Trainer()
add_to_class(Trainer)
def prepare_batch(self, batch):
return batch
add_to_class(Trainer)
def fit_epoch(self):
self.model.train()
for batch in self.train_dataloader:
loss = self.model.training_step(self.prepare_batch(batch))
self.optim.zero_grad()
with torch.no_grad():
loss.backward()
if self.gradient_clip_val > 0: # To be discussed later
self.clip_gradients(self.gradient_clip_val, self.model)
self.optim.step()
self.train_batch_idx += 1
if self.val_dataloader is None:
return
self.model.eval()
for batch in self.val_dataloader:
with torch.no_grad():
self.model.validation_step(self.prepare_batch(batch))
self.val_batch_idx += 1
model = LinearRegressionScratch(2, lr=0.03)
data = SyntheticRegressionData(w=torch.tensor([2, 3.5]), b=3)
trainer = Trainer(max_epochs=3)
trainer.fit(model, data)
TypeError Traceback (most recent call last)
Cell In[21], line 3
1 model = LinearRegressionScratch(2, lr=0.03)
2 data = SyntheticRegressionData(w=torch.tensor([2, 3.5]), b=3)
----> 3 trainer = Trainer(max_epochs=3)
4 trainer.fit(model, data)
TypeError: 'Trainer' object is not callable
Beta Was this translation helpful? Give feedback.
All reactions