9
9
from flax .metrics import tensorboard
10
10
from flax .training import train_state
11
11
import ml_collections
12
- from neurallogic import (hard_not , hard_or , harden , harden_layer ,
13
- neural_logic_net )
12
+ from neurallogic import hard_not , hard_or , harden , harden_layer , neural_logic_net
14
13
import optax
15
14
16
15
23
22
24
23
25
24
def nln (type , x , width ):
26
- x = hard_or .or_layer (type )(width , nn .initializers .uniform (
27
- 1.0 ), dtype = jnp .float32 )(x ) # >=1700 need for >98% accuracy
28
- x = hard_not .not_layer (type )(10 , dtype = jnp .float32 )(x )
25
+ x = hard_or .or_layer (type )(width , nn .initializers .uniform (1.0 ))(x )
26
+ x = hard_not .not_layer (type )(10 )(x )
29
27
x = x .ravel () # flatten the outputs of the not layer
30
28
# harden the outputs of the not layer
31
29
x = harden_layer .harden_layer (type )(x )
@@ -59,11 +57,11 @@ def __call__(self, x):
59
57
@jax .jit
60
58
def apply_model_with_grad (state , images , labels ):
61
59
"""Computes gradients, loss and accuracy for a single batch."""
60
+
62
61
def loss_fn (params ):
63
- logits = state .apply_fn ({' params' : params }, images )
62
+ logits = state .apply_fn ({" params" : params }, images )
64
63
one_hot = jax .nn .one_hot (labels , 10 )
65
- loss = jnp .mean (optax .softmax_cross_entropy (
66
- logits = logits , labels = one_hot ))
64
+ loss = jnp .mean (optax .softmax_cross_entropy (logits = logits , labels = one_hot ))
67
65
return loss , logits
68
66
69
67
grad_fn = jax .value_and_grad (loss_fn , has_aux = True )
@@ -79,21 +77,20 @@ def update_model(state, grads):
79
77
80
78
def train_epoch (state , train_ds , batch_size , rng ):
81
79
"""Train for a single epoch."""
82
- train_ds_size = len (train_ds [' image' ])
80
+ train_ds_size = len (train_ds [" image" ])
83
81
steps_per_epoch = train_ds_size // batch_size
84
82
85
- perms = jax .random .permutation (rng , len (train_ds [' image' ]))
86
- perms = perms [:steps_per_epoch * batch_size ] # skip incomplete batch
83
+ perms = jax .random .permutation (rng , len (train_ds [" image" ]))
84
+ perms = perms [: steps_per_epoch * batch_size ] # skip incomplete batch
87
85
perms = perms .reshape ((steps_per_epoch , batch_size ))
88
86
89
87
epoch_loss = []
90
88
epoch_accuracy = []
91
89
92
90
for perm in perms :
93
- batch_images = train_ds ['image' ][perm , ...]
94
- batch_labels = train_ds ['label' ][perm , ...]
95
- grads , loss , accuracy = apply_model_with_grad (
96
- state , batch_images , batch_labels )
91
+ batch_images = train_ds ["image" ][perm , ...]
92
+ batch_labels = train_ds ["label" ][perm , ...]
93
+ grads , loss , accuracy = apply_model_with_grad (state , batch_images , batch_labels )
97
94
state = update_model (state , grads )
98
95
epoch_loss .append (loss )
99
96
epoch_accuracy .append (accuracy )
@@ -103,24 +100,23 @@ def train_epoch(state, train_ds, batch_size, rng):
103
100
104
101
105
102
def get_datasets ():
106
- ds_builder = tfds .builder (' mnist' )
103
+ ds_builder = tfds .builder (" mnist" )
107
104
ds_builder .download_and_prepare ()
108
- train_ds = tfds .as_numpy (
109
- ds_builder .as_dataset (split = 'train' , batch_size = - 1 ))
110
- test_ds = tfds .as_numpy (ds_builder .as_dataset (split = 'test' , batch_size = - 1 ))
111
- train_ds ['image' ] = jnp .float32 (train_ds ['image' ]) / 255.
112
- test_ds ['image' ] = jnp .float32 (test_ds ['image' ]) / 255.
105
+ train_ds = tfds .as_numpy (ds_builder .as_dataset (split = "train" , batch_size = - 1 ))
106
+ test_ds = tfds .as_numpy (ds_builder .as_dataset (split = "test" , batch_size = - 1 ))
107
+ train_ds ["image" ] = jnp .float32 (train_ds ["image" ]) / 255.0
108
+ test_ds ["image" ] = jnp .float32 (test_ds ["image" ]) / 255.0
113
109
# Convert the floating point values in [0,1] to binary values in {0,1}
114
- train_ds [' image' ] = jnp .round (train_ds [' image' ])
115
- test_ds [' image' ] = jnp .round (test_ds [' image' ])
110
+ train_ds [" image" ] = jnp .round (train_ds [" image" ])
111
+ test_ds [" image" ] = jnp .round (test_ds [" image" ])
116
112
return train_ds , test_ds
117
113
118
114
119
115
def show_img (img , ax = None , title = None ):
120
116
"""Shows a single image."""
121
117
if ax is None :
122
118
ax = plt .gca ()
123
- ax .imshow (img .reshape (28 , 28 ), cmap = ' gray' )
119
+ ax .imshow (img .reshape (28 , 28 ), cmap = " gray" )
124
120
ax .set_xticks ([])
125
121
ax .set_yticks ([])
126
122
if title :
@@ -129,7 +125,7 @@ def show_img(img, ax=None, title=None):
129
125
130
126
def show_img_grid (imgs , titles ):
131
127
"""Shows a grid of images."""
132
- n = int (np .ceil (len (imgs )** .5 ))
128
+ n = int (np .ceil (len (imgs ) ** 0 .5 ))
133
129
_ , axs = plt .subplots (n , n , figsize = (3 * n , 3 * n ))
134
130
for i , (img , title ) in enumerate (zip (imgs , titles )):
135
131
show_img (img , axs [i // n ][i % n ], title )
@@ -141,13 +137,14 @@ def create_train_state(net, rng, config):
141
137
# mock_input = jnp.ones([1, 28, 28, 1])
142
138
# for NLN
143
139
mock_input = jnp .ones ([1 , 28 * 28 ])
144
- soft_weights = net .init (rng , mock_input )[' params' ]
140
+ soft_weights = net .init (rng , mock_input )[" params" ]
145
141
tx = optax .sgd (config .learning_rate , config .momentum )
146
142
return train_state .TrainState .create (apply_fn = net .apply , params = soft_weights , tx = tx )
147
143
148
144
149
- def train_and_evaluate (net , datasets , config : ml_collections .ConfigDict ,
150
- workdir : str ) -> train_state .TrainState :
145
+ def train_and_evaluate (
146
+ net , datasets , config : ml_collections .ConfigDict , workdir : str
147
+ ) -> train_state .TrainState :
151
148
"""Execute model training and evaluation loop.
152
149
Args:
153
150
config: Hyperparameter configuration for training and evaluation.
@@ -166,21 +163,22 @@ def train_and_evaluate(net, datasets, config: ml_collections.ConfigDict,
166
163
167
164
for epoch in range (1 , config .num_epochs + 1 ):
168
165
rng , input_rng = jax .random .split (rng )
169
- state , train_loss , train_accuracy = train_epoch (state , train_ds ,
170
- config .batch_size ,
171
- input_rng )
172
- _ , test_loss , test_accuracy = apply_model_with_grad (state , test_ds ['image' ],
173
- test_ds ['label' ])
166
+ state , train_loss , train_accuracy = train_epoch (
167
+ state , train_ds , config .batch_size , input_rng
168
+ )
169
+ _ , test_loss , test_accuracy = apply_model_with_grad (
170
+ state , test_ds ["image" ], test_ds ["label" ]
171
+ )
174
172
175
173
print (
176
- ' epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
177
- % (epoch , train_loss , train_accuracy * 100 , test_loss ,
178
- test_accuracy * 100 ) )
174
+ " epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
175
+ % (epoch , train_loss , train_accuracy * 100 , test_loss , test_accuracy * 100 )
176
+ )
179
177
180
- summary_writer .scalar (' train_loss' , train_loss , epoch )
181
- summary_writer .scalar (' train_accuracy' , train_accuracy , epoch )
182
- summary_writer .scalar (' test_loss' , test_loss , epoch )
183
- summary_writer .scalar (' test_accuracy' , test_accuracy , epoch )
178
+ summary_writer .scalar (" train_loss" , train_loss , epoch )
179
+ summary_writer .scalar (" train_accuracy" , train_accuracy , epoch )
180
+ summary_writer .scalar (" test_loss" , test_loss , epoch )
181
+ summary_writer .scalar (" test_accuracy" , test_accuracy , epoch )
184
182
185
183
return state
186
184
@@ -197,13 +195,13 @@ def get_config():
197
195
# Always commit with num_epochs = 1 for short test time
198
196
config .momentum = 0.9
199
197
config .batch_size = 128
200
- config .num_epochs = 2
198
+ config .num_epochs = 1000
201
199
return config
202
200
203
201
204
202
def apply_hard_model (state , image , label ):
205
203
def logits_fn (params ):
206
- return state .apply_fn ({' params' : params }, image )
204
+ return state .apply_fn ({" params" : params }, image )
207
205
208
206
logits = logits_fn (state .params )
209
207
if isinstance (logits , list ):
@@ -224,33 +222,41 @@ def check_symbolic(nets, datasets, trained_state):
224
222
_ , test_ds = datasets
225
223
_ , hard , symbolic = nets
226
224
_ , test_loss , test_accuracy = apply_model_with_grad (
227
- trained_state , test_ds ['image' ], test_ds ['label' ])
228
- print ('soft_net: final test_loss: %.4f, final test_accuracy: %.2f' %
229
- (test_loss , test_accuracy * 100 ))
225
+ trained_state , test_ds ["image" ], test_ds ["label" ]
226
+ )
227
+ print (
228
+ "soft_net: final test_loss: %.4f, final test_accuracy: %.2f"
229
+ % (test_loss , test_accuracy * 100 )
230
+ )
230
231
hard_weights = harden .hard_weights (trained_state .params )
231
232
hard_trained_state = train_state .TrainState .create (
232
- apply_fn = hard .apply , params = hard_weights , tx = optax .sgd (1.0 , 1.0 ))
233
- hard_input = harden .harden (test_ds ['image' ])
233
+ apply_fn = hard .apply , params = hard_weights , tx = optax .sgd (1.0 , 1.0 )
234
+ )
235
+ hard_input = harden .harden (test_ds ["image" ])
234
236
hard_test_accuracy = apply_hard_model_to_images (
235
- hard_trained_state , hard_input , test_ds ['label' ])
236
- print ('hard_net: final test_accuracy: %.2f' % (hard_test_accuracy * 100 ))
237
+ hard_trained_state , hard_input , test_ds ["label" ]
238
+ )
239
+ print ("hard_net: final test_accuracy: %.2f" % (hard_test_accuracy * 100 ))
237
240
assert np .isclose (test_accuracy , hard_test_accuracy , atol = 0.0001 )
241
+ # TODO: activate these checks
238
242
if False :
239
243
# It takes too long to compute this
240
244
symbolic_weights = harden .symbolic_weights (trained_state .params )
241
245
symbolic_trained_state = train_state .TrainState .create (
242
- apply_fn = symbolic .apply , params = symbolic_weights , tx = optax .sgd (1.0 , 1.0 ))
246
+ apply_fn = symbolic .apply , params = symbolic_weights , tx = optax .sgd (1.0 , 1.0 )
247
+ )
243
248
symbolic_input = hard_input .tolist ()
244
249
symbolic_test_accuracy = apply_hard_model_to_images (
245
- symbolic_trained_state , symbolic_input , test_ds ['label' ])
246
- print ('symbolic_net: final test_accuracy: %.2f' %
247
- (symbolic_test_accuracy * 100 ))
248
- assert (np .isclose (test_accuracy , symbolic_test_accuracy , atol = 0.0001 ))
250
+ symbolic_trained_state , symbolic_input , test_ds ["label" ]
251
+ )
252
+ print (
253
+ "symbolic_net: final test_accuracy: %.2f" % (symbolic_test_accuracy * 100 )
254
+ )
255
+ assert np .isclose (test_accuracy , symbolic_test_accuracy , atol = 0.0001 )
249
256
if False :
250
257
# CPU and GPU give different results, so we can't easily regress on a static symbolic expression
251
258
symbolic_input = [f"x{ i } " for i in range (len (hard_input [0 ].tolist ()))]
252
- symbolic_output = symbolic .apply (
253
- {'params' : symbolic_weights }, symbolic_input )
259
+ symbolic_output = symbolic .apply ({"params" : symbolic_weights }, symbolic_input )
254
260
print ("symbolic_output" , symbolic_output [0 ][:10000 ])
255
261
256
262
@@ -263,23 +269,20 @@ def test_mnist():
263
269
264
270
# Define the model.
265
271
# soft = CNN()
266
- width = 100
267
- soft , _ , _ = neural_logic_net .net (
268
- lambda type , x : batch_nln (type , x , width ))
272
+ width = 1000
273
+ soft , _ , _ = neural_logic_net .net (lambda type , x : batch_nln (type , x , width ))
269
274
270
275
# Get the MNIST dataset.
271
276
train_ds , test_ds = get_datasets ()
272
277
# If we're using a NLN then flatten the images
273
- train_ds ["image" ] = jnp .reshape (
274
- train_ds ["image" ], (train_ds ["image" ].shape [0 ], - 1 ))
275
- test_ds ["image" ] = jnp .reshape (
276
- test_ds ["image" ], (test_ds ["image" ].shape [0 ], - 1 ))
278
+ train_ds ["image" ] = jnp .reshape (train_ds ["image" ], (train_ds ["image" ].shape [0 ], - 1 ))
279
+ test_ds ["image" ] = jnp .reshape (test_ds ["image" ], (test_ds ["image" ].shape [0 ], - 1 ))
277
280
278
281
# Train and evaluate the model.
279
282
trained_state = train_and_evaluate (
280
- soft , (train_ds , test_ds ), config = config , workdir = "./mnist_metrics" )
283
+ soft , (train_ds , test_ds ), config = config , workdir = "./mnist_metrics"
284
+ )
281
285
282
286
# Check symbolic net
283
- _ , hard , symbolic = neural_logic_net .net (
284
- lambda type , x : nln (type , x , width ))
287
+ _ , hard , symbolic = neural_logic_net .net (lambda type , x : nln (type , x , width ))
285
288
check_symbolic ((soft , hard , symbolic ), (train_ds , test_ds ), trained_state )
0 commit comments