pytorch首战——线性网络及其实现

运用pytorch构建一个简单的线性神经网络

本文记录了一些pytorch的基本操作,实现了一个简答的线性回归,一个softmax分类和一个多层感知机(参照《动手学深度学习pytorch版)

一.简单线性回归

通过此练习训练一个能回归到 $$y=x_1w_1+x_2w_2+b$$ 的线性神经网络

1.数据集的生成

1
2
3
4
5
6
7
8
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 直接生成相应的1000*2的张量,并规定数据类型
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b#这里用[:]通配指定与每一维相乘得到的还是一个1*1000的张量
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32) #加上高斯分布的干扰

2.数据集的读取

1
2
3
4
5
6
7
8
9
10
11
def data_iter(batch_size,features,labels):
#对整个数据集生成一个打乱的下标组,这样才能够产生一个随机的batch
num_examples = len(features)#对张量对象可以用len()方法
indices = list(range(num_examples))
#打乱下表数组
random.shuffle(indices)
#开始抽签
for i in range(0,num_examples,batch_size):
j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])#要取的下标序列,注意最后一个可能不足一个batch的元素个数
yield features.index_select(0,j),labels.index_select(0,j)#实际上是一个生成器,可以源源不断的产生数据,index_select方法中0表示按行索引,1表示按列索引

当然,我们也可以用pytorch中封装好的方法来直接进行实现(工具越强大,人就越懒):

1
2
3
4
5
import torch.utils.data as Data
batch_size = 10
dataset = Data.TensorDataset(features,labels)#组合定义好的特征与标签
#随机读取小批量
data_iter = Data.DataLoader(dataset,batch_size,shuffle = True);

3.定义模型

我们导入pytorch中的模块torch.nn,常用的做法是继承nn.Module,编写自己的网络和层

1
2
3
4
5
6
7
8
9
10
class LinearNet(nn.Module):
def __init__(self,n_feature):
super(LinearNet,self).__init__()#用super()函数调用父类的构造函数
self.linear = nn.Linear(n_feature,1)#in_feature和out_feature
def forward(self,x):#定义前向传播
y=self.linear(x)
return y
net = LinearNet(num_inputs)


在这里我们首先要定义我们网络的初始化参数,即Linear(in_features,out_features,weight,bias);
然后要定义前向传播函数,这个前向传播函数实际上就是给一个输入然后返回输出,在之后求梯度的时候会有用到。

4.初始化模型参数

通过导入init模块,用init.normal_将权重参数每个元素初始化为随机采样与均值为0,标准差为0.01的正态分布;通过init.constant()将bias初始化为常数0

1
2
3
from torch.nn import init
init.normal_(net.linear.weight,mean=0,std=0.01);
init.constant_(net.linear.bias,val = 0)

5.定义损失函数

1
loss = nn.MSEloss()#均方误差

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!