Skip to content

Commit c9c62d0

Browse files
committed
Transfer learning mnist
1 parent 4a8c495 commit c9c62d0

File tree

2 files changed

+803
-0
lines changed

2 files changed

+803
-0
lines changed
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [
10+
{
11+
"name": "stderr",
12+
"output_type": "stream",
13+
"text": [
14+
"Using TensorFlow backend.\n"
15+
]
16+
}
17+
],
18+
"source": [
19+
"import numpy as np\n",
20+
"from matplotlib import pyplot as plt\n",
21+
"%matplotlib inline\n",
22+
"\n",
23+
"import keras\n",
24+
"from keras.datasets import mnist, cifar10\n",
25+
"from keras.layers import Dense, Convolution2D, Flatten, Activation, MaxPool2D, Dropout, Flatten\n",
26+
"from keras.models import Sequential\n",
27+
"from keras.utils import np_utils"
28+
]
29+
},
30+
{
31+
"cell_type": "code",
32+
"execution_count": 2,
33+
"metadata": {
34+
"collapsed": true
35+
},
36+
"outputs": [],
37+
"source": [
38+
"# Import MNIST Datasets\n",
39+
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
40+
"n_examples = 40000"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 3,
46+
"metadata": {
47+
"collapsed": false
48+
},
49+
"outputs": [],
50+
"source": [
51+
"X1_train = []\n",
52+
"X1_test = []\n",
53+
"\n",
54+
"X2_train = []\n",
55+
"X2_test = []\n",
56+
"\n",
57+
"Y1_train = []\n",
58+
"Y1_test = []\n",
59+
"\n",
60+
"Y2_train = []\n",
61+
"Y2_test = []\n",
62+
"\n",
63+
"for ix in range(n_examples):\n",
64+
" if y_train[ix] < 5:\n",
65+
" # Put data in set 01\n",
66+
" X1_train.append(x_train[ix]/255.0)\n",
67+
" Y1_train.append(y_train[ix])\n",
68+
" else:\n",
69+
" # Put data in set 02\n",
70+
" X2_train.append(x_train[ix]/255.0)\n",
71+
" Y2_train.append(y_train[ix])\n",
72+
"\n",
73+
"for ix in range(y_test.shape[0]):\n",
74+
" if y_test[ix] < 5:\n",
75+
" # Put data in set 01\n",
76+
" X1_test.append(x_test[ix]/255.0)\n",
77+
" Y1_test.append(y_test[ix])\n",
78+
" else:\n",
79+
" # Put data in set 02\n",
80+
" X2_test.append(x_test[ix]/255.0)\n",
81+
" Y2_test.append(y_test[ix])"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": 4,
87+
"metadata": {
88+
"collapsed": false
89+
},
90+
"outputs": [],
91+
"source": [
92+
"X1_train = np.asarray(X1_train).reshape((-1, 28, 28, 1))\n",
93+
"X1_test = np.asarray(X1_test).reshape((-1, 28, 28, 1))\n",
94+
"\n",
95+
"X2_train = np.asarray(X2_train).reshape((-1, 28, 28, 1))\n",
96+
"X2_test = np.asarray(X2_test).reshape((-1, 28, 28, 1))\n",
97+
"\n",
98+
"Y1_train = np_utils.to_categorical(np.asarray(Y1_train), num_classes=5)\n",
99+
"Y1_test = np_utils.to_categorical(np.asarray(Y1_test), num_classes=5)\n",
100+
"\n",
101+
"Y2_train = np_utils.to_categorical(np.asarray(Y2_train), num_classes=10)\n",
102+
"Y2_test = np_utils.to_categorical(np.asarray(Y2_test), num_classes=10)"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": 5,
108+
"metadata": {
109+
"collapsed": true
110+
},
111+
"outputs": [],
112+
"source": [
113+
"split1 = int(0.8 * X1_train.shape[0])\n",
114+
"split2 = int(0.8 * X2_train.shape[0])\n",
115+
"\n",
116+
"x1_val = X1_train[split1:]\n",
117+
"x1_train = X1_train[:split1]\n",
118+
"y1_val = Y1_train[split1:]\n",
119+
"y1_train = Y1_train[:split1]\n",
120+
"\n",
121+
"x2_val = X2_train[split2:]\n",
122+
"x2_train = X2_train[:split2]\n",
123+
"y2_val = Y2_train[split2:]\n",
124+
"y2_train = Y2_train[:split2]\n"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": 6,
130+
"metadata": {
131+
"collapsed": false
132+
},
133+
"outputs": [
134+
{
135+
"name": "stdout",
136+
"output_type": "stream",
137+
"text": [
138+
"(16336, 28, 28, 1) (5139, 28, 28, 1)\n",
139+
"(20420, 5) (5139, 5)\n",
140+
"(19580, 28, 28, 1) (4861, 28, 28, 1)\n",
141+
"(19580, 10) (4861, 10)\n"
142+
]
143+
}
144+
],
145+
"source": [
146+
"print x1_train.shape, X1_test.shape\n",
147+
"print Y1_train.shape, Y1_test.shape\n",
148+
"\n",
149+
"print X2_train.shape, X2_test.shape\n",
150+
"print Y2_train.shape, Y2_test.shape"
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": 7,
156+
"metadata": {
157+
"collapsed": false
158+
},
159+
"outputs": [
160+
{
161+
"name": "stdout",
162+
"output_type": "stream",
163+
"text": [
164+
"_________________________________________________________________\n",
165+
"Layer (type) Output Shape Param # \n",
166+
"=================================================================\n",
167+
"conv2d_1 (Conv2D) (None, 24, 24, 32) 832 \n",
168+
"_________________________________________________________________\n",
169+
"conv2d_2 (Conv2D) (None, 20, 20, 16) 12816 \n",
170+
"_________________________________________________________________\n",
171+
"max_pooling2d_1 (MaxPooling2 (None, 10, 10, 16) 0 \n",
172+
"_________________________________________________________________\n",
173+
"conv2d_3 (Conv2D) (None, 8, 8, 8) 1160 \n",
174+
"_________________________________________________________________\n",
175+
"flatten_1 (Flatten) (None, 512) 0 \n",
176+
"_________________________________________________________________\n",
177+
"dropout_1 (Dropout) (None, 512) 0 \n",
178+
"_________________________________________________________________\n",
179+
"dense_1 (Dense) (None, 128) 65664 \n",
180+
"_________________________________________________________________\n",
181+
"activation_1 (Activation) (None, 128) 0 \n",
182+
"_________________________________________________________________\n",
183+
"dense_2 (Dense) (None, 5) 645 \n",
184+
"_________________________________________________________________\n",
185+
"activation_2 (Activation) (None, 5) 0 \n",
186+
"=================================================================\n",
187+
"Total params: 81,117.0\n",
188+
"Trainable params: 81,117.0\n",
189+
"Non-trainable params: 0.0\n",
190+
"_________________________________________________________________\n"
191+
]
192+
}
193+
],
194+
"source": [
195+
"model = Sequential()\n",
196+
"\n",
197+
"model.add(Convolution2D(32, 5, input_shape=(28, 28, 1), activation='relu'))\n",
198+
"model.add(Convolution2D(16, 5, activation='relu'))\n",
199+
"model.add(MaxPool2D(pool_size=(2, 2)))\n",
200+
"model.add(Convolution2D(8, 3, activation='relu'))\n",
201+
"model.add(Flatten())\n",
202+
"model.add(Dropout(0.42))\n",
203+
"\n",
204+
"model.add(Dense(128))\n",
205+
"model.add(Activation('relu'))\n",
206+
"\n",
207+
"model.add(Dense(5))\n",
208+
"model.add(Activation('softmax'))\n",
209+
"\n",
210+
"model.summary()\n",
211+
"model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
212+
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": 13,
217+
"metadata": {
218+
"collapsed": false
219+
},
220+
"outputs": [
221+
{
222+
"name": "stdout",
223+
"output_type": "stream",
224+
"text": [
225+
"0:00:05.005437\n"
226+
]
227+
}
228+
],
229+
"source": [
230+
"# Add Time module to track training time\n",
231+
"import time\n",
232+
"import datetime\n",
233+
"\n",
234+
"a = datetime.datetime.now()\n",
235+
"time.sleep(5)\n",
236+
"print datetime.datetime.now() - a"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": 9,
242+
"metadata": {
243+
"collapsed": false
244+
},
245+
"outputs": [
246+
{
247+
"name": "stderr",
248+
"output_type": "stream",
249+
"text": [
250+
"/usr/local/lib/python2.7/dist-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
251+
" warnings.warn('The `nb_epoch` argument in `fit` '\n"
252+
]
253+
},
254+
{
255+
"name": "stdout",
256+
"output_type": "stream",
257+
"text": [
258+
"Train on 16336 samples, validate on 4084 samples\n",
259+
"Epoch 1/10\n",
260+
"3s - loss: 0.2427 - acc: 0.9158 - val_loss: 0.0447 - val_acc: 0.9858\n",
261+
"Epoch 2/10\n",
262+
"2s - loss: 0.0595 - acc: 0.9809 - val_loss: 0.0329 - val_acc: 0.9890\n",
263+
"Epoch 3/10\n",
264+
"2s - loss: 0.0405 - acc: 0.9873 - val_loss: 0.0232 - val_acc: 0.9917\n",
265+
"Epoch 4/10\n",
266+
"2s - loss: 0.0321 - acc: 0.9890 - val_loss: 0.0176 - val_acc: 0.9946\n",
267+
"Epoch 5/10\n",
268+
"2s - loss: 0.0262 - acc: 0.9920 - val_loss: 0.0131 - val_acc: 0.9956\n",
269+
"Epoch 6/10\n",
270+
"2s - loss: 0.0202 - acc: 0.9935 - val_loss: 0.0245 - val_acc: 0.9922\n",
271+
"Epoch 7/10\n",
272+
"2s - loss: 0.0179 - acc: 0.9941 - val_loss: 0.0157 - val_acc: 0.9944\n",
273+
"Epoch 8/10\n",
274+
"2s - loss: 0.0173 - acc: 0.9944 - val_loss: 0.0168 - val_acc: 0.9949\n",
275+
"Epoch 9/10\n",
276+
"2s - loss: 0.0131 - acc: 0.9949 - val_loss: 0.0102 - val_acc: 0.9976\n",
277+
"Epoch 10/10\n",
278+
"2s - loss: 0.0143 - acc: 0.9961 - val_loss: 0.0130 - val_acc: 0.9966\n"
279+
]
280+
}
281+
],
282+
"source": [
283+
"start = datetime.datetime.now()\n",
284+
"hist1 = model.fit(x1_train, y1_train,\n",
285+
" nb_epoch=10,\n",
286+
" shuffle=True,\n",
287+
" batch_size=100,\n",
288+
" validation_data=(x1_val, y1_val), verbose=2)\n",
289+
"\n",
290+
"time_taken = datetime.datetime.now() - start"
291+
]
292+
},
293+
{
294+
"cell_type": "code",
295+
"execution_count": null,
296+
"metadata": {
297+
"collapsed": false
298+
},
299+
"outputs": [],
300+
"source": []
301+
}
302+
],
303+
"metadata": {
304+
"kernelspec": {
305+
"display_name": "Python 2",
306+
"language": "python",
307+
"name": "python2"
308+
},
309+
"language_info": {
310+
"codemirror_mode": {
311+
"name": "ipython",
312+
"version": 2
313+
},
314+
"file_extension": ".py",
315+
"mimetype": "text/x-python",
316+
"name": "python",
317+
"nbconvert_exporter": "python",
318+
"pygments_lexer": "ipython2",
319+
"version": "2.7.12"
320+
}
321+
},
322+
"nbformat": 4,
323+
"nbformat_minor": 2
324+
}

0 commit comments

Comments
 (0)