我們需要做的第⼀件事情是獲取 MNIST 數據。如果你是⼀個 git ⽤⼾,那么你能夠通過克隆這本書的代碼倉庫獲得數據,實現我們的⽹絡來分類數字
git clone https://github.com/mnielsen/neural-networks-and-deep-learning.git
class Network(object):def __init__(self, sizes):self.num_layers = len(sizes)self.sizes = sizesself.biases = [np.random.randn(y, 1) for y in sizes[1:]]self.weights = [np.random.randn(y, x)for x, y in zip(sizes[:-1], sizes[1:])]
在這段代碼中,列表 sizes 包含各層神經元的數量。例如,如果我們想創建⼀個在第⼀層有2 個神經元,第⼆層有 3 個神經元,最后層有 1 個神經元的 Network 對象,我們應這樣寫代碼:
net = Network([2, 3, 1])
Network 對象中的偏置和權重都是被隨機初始化的,使⽤ Numpy 的 np.random.randn 函數來⽣成均值為 0,標準差為 1 的⾼斯分布。這樣的隨機初始化給了我們的隨機梯度下降算法⼀個起點。在后⾯的章節中我們將會發現更好的初始化權重和偏置的⽅法,但是⽬前隨機地將其初始化。注意 Network 初始化代碼假設第⼀層神經元是⼀個輸⼊層,并對這些神經元不設置任何偏置,因為偏置僅在后⾯的層中⽤于計算輸出。有了這些,很容易寫出從⼀個 Network 實例計算輸出的代碼。我們從定義 S 型函數開始:
def sigmoid(z):return 1.0/(1.0+np.exp(-z))
注意,當輸⼊ z 是⼀個向量或者 Numpy 數組時,Numpy ⾃動地按元素應⽤ sigmoid 函數,即以向量形式。
我們然后對 Network 類添加⼀個 feedforward ⽅法,對于⽹絡給定⼀個輸⼊ a,返回對應的輸出 6 。這個⽅法所做的是對每⼀層應⽤⽅程 (22):
def feedforward(self, a):"""Return the output of the network if "a" is input."""for b, w in zip(self.biases, self.weights):a = sigmoid(np.dot(w, a)+b)return a
當然,我們想要 Network 對象做的主要事情是學習。為此我們給它們⼀個實現隨即梯度下降算法的 SGD ⽅法。代碼如下。其中⼀些地⽅看似有⼀點神秘,我會在代碼后⾯逐個分析
def SGD(self, training_data, epochs, mini_batch_size, eta,test_data=None):"""Train the neural network using mini-batch stochasticgradient descent. The "training_data" is a list of tuples"(x, y)" representing the training inputs and the desiredoutputs. The other non-optional parameters areself-explanatory. If "test_data" is provided then thenetwork will be evaluated against the test data after eachepoch, and partial progress printed out. This is useful fortracking progress, but slows things down substantially."""if test_data: n_test = len(test_data)n = len(training_data)for j in xrange(epochs):random.shuffle(training_data)mini_batches = [training_data[k:k+mini_batch_size]for k in xrange(0, n, mini_batch_size)]for mini_batch in mini_batches:self.update_mini_batch(mini_batch, eta)if test_data:print "Epoch {0}: {1} / {2}".format(j, self.evaluate(test_data), n_test)else:print "Epoch {0} complete".format(j)
新聞熱點
疑難解答