# __getitem__ 方法 (the __getitem__ method):
class MyList:
    """Minimal list-like container demonstrating ``__getitem__``.

    Defining ``__getitem__`` makes instances subscriptable (``obj[idx]``).
    The index -- an int, a negative int, or a slice -- is forwarded unchanged
    to the wrapped Python list, so the list's own IndexError/TypeError
    behaviour applies.
    """

    def __init__(self, data=None):
        """Wrap *data*, defaulting to ``[10, 20, 30, 40, 50]``.

        Args:
            data: optional iterable of items. It is copied into a new list so
                later mutation of the caller's sequence does not leak in.
                ``None`` (not a mutable default) selects the demo data.
        """
        self.data = [10, 20, 30, 40, 50] if data is None else list(data)

    def __getitem__(self, idx):
        """Return ``self.data[idx]`` (ints, negative ints and slices work)."""
        return self.data[idx]


my_list_obj = MyList()
print(my_list_obj[2])  # third element of the default data -> 30
# __len__ 方法 (the __len__ method):
class MyList:
    """Minimal container demonstrating ``__len__``.

    Defining ``__len__`` lets the built-in ``len()`` work on instances; it
    simply reports the length of the wrapped list.

    NOTE: this class rebinds the module-level name ``MyList``, shadowing any
    earlier definition of that name in this module.
    """

    def __init__(self, data=None):
        """Wrap *data*, defaulting to ``[10, 20, 30, 40, 50]``.

        Args:
            data: optional iterable of items, copied into a new list.
                ``None`` (not a mutable default) selects the demo data.
        """
        self.data = [10, 20, 30, 40, 50] if data is None else list(data)

    def __len__(self):
        """Return the number of wrapped items."""
        return len(self.data)


my_list_obj = MyList()
print(len(my_list_obj))  # default data has five items -> 5
# hook 函数 (PyTorch hook functions):
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Seed NumPy's global RNG and PyTorch's RNG with the same value so that
# repeated runs of the examples below produce identical random numbers.
np.random.seed(42)
torch.manual_seed(42)
# 张量钩子 (tensor hooks):
# A hook registered on a tensor receives the gradient w.r.t. that tensor
# during backward(); if it returns a tensor, autograd propagates that value
# further back instead of the original gradient.
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2  # dy/dx = 2x = 4 at x = 2
z = y ** 3  # dz/dy = 3y^2 = 48 at y = 4


def tensor_hook(grad):
    """Print the incoming gradient and return it halved.

    Args:
        grad: gradient of the backward output w.r.t. the hooked tensor.

    Returns:
        ``grad / 2``, which replaces ``grad`` for the rest of the backward pass.
    """
    print(f"原始梯度: {grad}")
    return grad / 2


hook_handle = y.register_hook(tensor_hook)
try:
    z.backward()
    # With the hook: x.grad = (48 / 2) * 4 = 96 (it would be 192 without it).
    print(f"x的梯度: {x.grad}")
finally:
    # Detach the hook even if backward() raises, so it cannot fire later.
    hook_handle.remove()
# @浙大疏锦行