Skip to content

Commit 255354b

Browse files
committed
Use stable keras and keras-hub, bug fixes
1 parent 944c7d3 commit 255354b

16 files changed

+155
-118
lines changed

chapter02_mathematical-building-blocks.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
"outputs": [],
1919
"source": [
20-
"!pip install keras-nightly keras-hub-nightly --upgrade -q"
20+
"!pip install keras keras-hub --upgrade -q"
2121
]
2222
},
2323
{
@@ -430,7 +430,7 @@
430430
"\n",
431431
"digit = train_images[4]\n",
432432
"plt.imshow(digit, cmap=plt.cm.binary)\n",
433-
"plt.savefig(\"The-fourth-sample-in-our-dataset.png\", dpi=300)"
433+
"plt.show()"
434434
]
435435
},
436436
{

chapter03_introduction-to-ml-frameworks.ipynb

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
"outputs": [],
1919
"source": [
20-
"!pip install keras-nightly keras-hub-nightly --upgrade -q"
20+
"!pip install keras keras-hub --upgrade -q"
2121
]
2222
},
2323
{
@@ -159,7 +159,7 @@
159159
"colab_type": "text"
160160
},
161161
"source": [
162-
"###### Tensor assignment and the Variable class"
162+
"###### Tensor assignment and the `Variable` class"
163163
]
164164
},
165165
{
@@ -413,7 +413,7 @@
413413
"import matplotlib.pyplot as plt\n",
414414
"\n",
415415
"plt.scatter(inputs[:, 0], inputs[:, 1], c=targets[:, 0])\n",
416-
"plt.savefig(\"linear_model_inputs.png\", dpi=300)"
416+
"plt.show()"
417417
]
418418
},
419419
{
@@ -499,7 +499,7 @@
499499
"source": [
500500
"predictions = model(inputs, W, b)\n",
501501
"plt.scatter(inputs[:, 0], inputs[:, 1], c=predictions[:, 0] > 0.5)\n",
502-
"plt.savefig(\"linear_model_predictions.png\", dpi=300)"
502+
"plt.show()"
503503
]
504504
},
505505
{
@@ -684,7 +684,7 @@
684684
"c = torch.sqrt(a)\n",
685685
"d = b + c\n",
686686
"e = torch.matmul(a, b)\n",
687-
"f = torch.cat((a, b), axis=0)"
687+
"f = torch.cat((a, b), dim=0)"
688688
]
689689
},
690690
{
@@ -825,7 +825,7 @@
825825
"colab_type": "text"
826826
},
827827
"source": [
828-
"##### Packaging state and computation with Modules"
828+
"##### Packaging state and computation with the `Module` class"
829829
]
830830
},
831831
{
@@ -914,7 +914,7 @@
914914
},
915915
"outputs": [],
916916
"source": [
917-
"compiled_model = model.compile()"
917+
"compiled_model = torch.compile(model)"
918918
]
919919
},
920920
{
@@ -1006,7 +1006,7 @@
10061006
"colab_type": "text"
10071007
},
10081008
"source": [
1009-
"#### Random tensors"
1009+
"#### Random number generation in JAX"
10101010
]
10111011
},
10121012
{
@@ -1060,7 +1060,7 @@
10601060
"source": [
10611061
"import jax\n",
10621062
"\n",
1063-
"seed_key = jax.random.PRNGKey(1337)"
1063+
"seed_key = jax.random.key(1337)"
10641064
]
10651065
},
10661066
{
@@ -1071,7 +1071,7 @@
10711071
},
10721072
"outputs": [],
10731073
"source": [
1074-
"seed_key = jax.random.PRNGKey(0)\n",
1074+
"seed_key = jax.random.key(0)\n",
10751075
"jax.random.normal(seed_key, shape=(3,))"
10761076
]
10771077
},
@@ -1083,7 +1083,7 @@
10831083
},
10841084
"outputs": [],
10851085
"source": [
1086-
"seed_key = jax.random.PRNGKey(123)\n",
1086+
"seed_key = jax.random.key(123)\n",
10871087
"jax.random.normal(seed_key, shape=(3,))"
10881088
]
10891089
},
@@ -1106,7 +1106,7 @@
11061106
},
11071107
"outputs": [],
11081108
"source": [
1109-
"seed_key = jax.random.PRNGKey(123)\n",
1109+
"seed_key = jax.random.key(123)\n",
11101110
"jax.random.normal(seed_key, shape=(3,))"
11111111
]
11121112
},
@@ -1353,6 +1353,8 @@
13531353
},
13541354
"outputs": [],
13551355
"source": [
1356+
"learning_rate = 0.1\n",
1357+
"\n",
13561358
"@jax.jit\n",
13571359
"def training_step(inputs, targets, W, b):\n",
13581360
" loss, grads = grad_fn((W, b), inputs, targets)\n",

chapter04_classification-and-regression.ipynb

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
"outputs": [],
1919
"source": [
20-
"!pip install keras-nightly keras-hub-nightly --upgrade -q"
20+
"!pip install keras keras-hub --upgrade -q"
2121
]
2222
},
2323
{
@@ -172,8 +172,8 @@
172172
},
173173
"outputs": [],
174174
"source": [
175-
"y_train = np.asarray(train_labels).astype(\"float32\")\n",
176-
"y_test = np.asarray(test_labels).astype(\"float32\")"
175+
"y_train = train_labels.astype(\"float32\")\n",
176+
"y_test = test_labels.astype(\"float32\")"
177177
]
178178
},
179179
{
@@ -310,7 +310,7 @@
310310
"plt.xticks(epochs)\n",
311311
"plt.ylabel(\"Loss\")\n",
312312
"plt.legend()\n",
313-
"plt.savefig(\"imdb_loss_plot.png\", dpi=300)"
313+
"plt.show()"
314314
]
315315
},
316316
{
@@ -331,7 +331,7 @@
331331
"plt.xticks(epochs)\n",
332332
"plt.ylabel(\"Accuracy\")\n",
333333
"plt.legend()\n",
334-
"plt.savefig(\"imdb_accuracy_plot.png\", dpi=300)"
334+
"plt.show()"
335335
]
336336
},
337337
{
@@ -654,7 +654,7 @@
654654
"plt.xticks(epochs)\n",
655655
"plt.ylabel(\"Loss\")\n",
656656
"plt.legend()\n",
657-
"plt.savefig(\"reuters_loss_plot.png\", dpi=300)"
657+
"plt.show()"
658658
]
659659
},
660660
{
@@ -675,7 +675,7 @@
675675
"plt.xticks(epochs)\n",
676676
"plt.ylabel(\"Accuracy\")\n",
677677
"plt.legend()\n",
678-
"plt.savefig(\"reuters_accuracy_plot.png\", dpi=300)"
678+
"plt.show()"
679679
]
680680
},
681681
{
@@ -696,7 +696,7 @@
696696
"plt.xticks(epochs)\n",
697697
"plt.ylabel(\"Top-3 accuracy\")\n",
698698
"plt.legend()\n",
699-
"plt.savefig(\"reuters_top_3_accuracy_plot.png\", dpi=300)"
699+
"plt.show()"
700700
]
701701
},
702702
{
@@ -750,7 +750,7 @@
750750
"import copy\n",
751751
"test_labels_copy = copy.copy(test_labels)\n",
752752
"np.random.shuffle(test_labels_copy)\n",
753-
"hits_array = np.array(test_labels)\n",
753+
"hits_array = np.array(test_labels == test_labels_copy)\n",
754754
"hits_array.mean()"
755755
]
756756
},
@@ -824,8 +824,8 @@
824824
},
825825
"outputs": [],
826826
"source": [
827-
"y_train = np.array(train_labels)\n",
828-
"y_test = np.array(test_labels)"
827+
"y_train = train_labels\n",
828+
"y_test = test_labels"
829829
]
830830
},
831831
{
@@ -1163,7 +1163,7 @@
11631163
"plt.plot(epochs, average_mae_history)\n",
11641164
"plt.xlabel(\"Epochs\")\n",
11651165
"plt.ylabel(\"Validation MAE\")\n",
1166-
"plt.savefig(\"california_housing_validation_mae_plot.png\", dpi=300)"
1166+
"plt.show()"
11671167
]
11681168
},
11691169
{
@@ -1179,7 +1179,7 @@
11791179
"plt.plot(epochs, truncated_mae_history)\n",
11801180
"plt.xlabel(\"Epochs\")\n",
11811181
"plt.ylabel(\"Validation MAE\")\n",
1182-
"plt.savefig(\"california_housing_validation_mae_plot_zoomed.png\", dpi=300)"
1182+
"plt.show()"
11831183
]
11841184
},
11851185
{
@@ -1225,7 +1225,7 @@
12251225
},
12261226
"outputs": [],
12271227
"source": [
1228-
"predictions = model.predict(test_data)\n",
1228+
"predictions = model.predict(x_test)\n",
12291229
"predictions[0]"
12301230
]
12311231
},

chapter05_fundamentals-of-ml.ipynb

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
"outputs": [],
1919
"source": [
20-
"!pip install keras-nightly keras-hub-nightly --upgrade -q"
20+
"!pip install keras keras-hub --upgrade -q"
2121
]
2222
},
2323
{
@@ -175,7 +175,7 @@
175175
"plt.xticks(epochs)\n",
176176
"plt.ylabel(\"Accuracy\")\n",
177177
"plt.legend()\n",
178-
"plt.savefig(\"mnist_with_added_noise_channels_or_zeros_channels.png\", dpi=300)"
178+
"plt.show()"
179179
]
180180
},
181181
{
@@ -445,7 +445,7 @@
445445
"plt.xlabel(\"Epochs\")\n",
446446
"plt.ylabel(\"Loss\")\n",
447447
"plt.legend()\n",
448-
"plt.savefig(\"effect_of_insufficient_model_capacity_on_val_loss.png\", dpi=300)"
448+
"plt.show()"
449449
]
450450
},
451451
{
@@ -485,14 +485,14 @@
485485
},
486486
"outputs": [],
487487
"source": [
488-
"val_loss = history_small_model.history[\"val_loss\"]\n",
488+
"val_loss = history_large_model.history[\"val_loss\"]\n",
489489
"epochs = range(1, 21)\n",
490490
"plt.plot(epochs, val_loss, \"b-\", label=\"Validation loss\")\n",
491491
"plt.title(\"Validation loss for a model with appropriate capacity\")\n",
492492
"plt.xlabel(\"Epochs\")\n",
493493
"plt.ylabel(\"Loss\")\n",
494494
"plt.legend()\n",
495-
"plt.savefig(\"effect_of_correct_model_capacity_on_val_loss.png\", dpi=300)"
495+
"plt.show()"
496496
]
497497
},
498498
{
@@ -507,6 +507,7 @@
507507
" [\n",
508508
" layers.Dense(2048, activation=\"relu\"),\n",
509509
" layers.Dense(2048, activation=\"relu\"),\n",
510+
" layers.Dense(2048, activation=\"relu\"),\n",
510511
" layers.Dense(10, activation=\"softmax\"),\n",
511512
" ]\n",
512513
")\n",
@@ -519,7 +520,7 @@
519520
" train_images,\n",
520521
" train_labels,\n",
521522
" epochs=20,\n",
522-
" batch_size=128,\n",
523+
" batch_size=32,\n",
523524
" validation_split=0.2,\n",
524525
")"
525526
]
@@ -539,7 +540,7 @@
539540
"plt.xlabel(\"Epochs\")\n",
540541
"plt.ylabel(\"Loss\")\n",
541542
"plt.legend()\n",
542-
"plt.savefig(\"effect_of_excessive_model_capacity_on_val_loss.png\", dpi=300)"
543+
"plt.show()"
543544
]
544545
},
545546
{
@@ -694,7 +695,7 @@
694695
"plt.ylabel(\"Loss\")\n",
695696
"plt.xticks(epochs)\n",
696697
"plt.legend()\n",
697-
"plt.savefig(\"original_model_vs_smaller_model_imdb.png\", dpi=300)"
698+
"plt.show()"
698699
]
699700
},
700701
{
@@ -735,7 +736,7 @@
735736
"outputs": [],
736737
"source": [
737738
"original_val_loss = history_original.history[\"val_loss\"]\n",
738-
"larger_model_val_loss = history_smaller_model.history[\"val_loss\"]\n",
739+
"larger_model_val_loss = history_larger_model.history[\"val_loss\"]\n",
739740
"epochs = range(1, 21)\n",
740741
"plt.plot(\n",
741742
" epochs,\n",
@@ -754,7 +755,7 @@
754755
"plt.ylabel(\"Loss\")\n",
755756
"plt.xticks(epochs)\n",
756757
"plt.legend()\n",
757-
"plt.savefig(\"original_model_vs_larger_model_imdb.png\", dpi=300)"
758+
"plt.show()"
758759
]
759760
},
760761
{
@@ -827,7 +828,7 @@
827828
"plt.ylabel(\"Loss\")\n",
828829
"plt.xticks(epochs)\n",
829830
"plt.legend()\n",
830-
"plt.savefig(\"original_model_vs_l2_regularized_model_imdb.png\", dpi=300)"
831+
"plt.show()"
831832
]
832833
},
833834
{
@@ -914,7 +915,7 @@
914915
"plt.ylabel(\"Loss\")\n",
915916
"plt.xticks(epochs)\n",
916917
"plt.legend()\n",
917-
"plt.savefig(\"original_model_vs_dropout_regularized_model_imdb.png\", dpi=300)"
918+
"plt.show()"
918919
]
919920
},
920921
{

0 commit comments

Comments
 (0)