@@ -13,11 +13,20 @@
 from matplotlib import pyplot as plt
 from tqdm import tqdm
 
-from neurallogic import (hard_and, hard_majority, hard_not, hard_or, hard_xor, harden,
-                         harden_layer, neural_logic_net, real_encoder)
+from neurallogic import (
+    hard_and,
+    hard_majority,
+    hard_not,
+    hard_or,
+    hard_xor,
+    harden,
+    harden_layer,
+    neural_logic_net,
+    real_encoder,
+)
 
 # Uncomment to debug NaNs
-#config.update("jax_debug_nans", True)
+# config.update("jax_debug_nans", True)
 
 """
 MNIST test.
@@ -44,15 +53,18 @@ def nln(type, x, width):
     return x
 """
 
+
 def nln(type, x):
     num_classes = 10
 
-    x = hard_or.or_layer(type)(1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
+    x = hard_or.or_layer(type)(
+        1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16
+    )(x)
     x = hard_not.not_layer(type)(1, dtype=jax.numpy.float16)(x)
     x = x.ravel()
-    x = harden_layer.harden_layer(type)(x)
-    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
-    x = x.sum(-1)
+    x = harden_layer.harden_layer(type)(x)
+    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
+    x = x.sum(-1)
     return x
 
 
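For orientation, the tail of `nln` is a per-class popcount: the 1800 hardened outputs are split into 10 groups of 180, and each group is summed into a class score. A minimal sketch with plain `jax.numpy`, using the shapes from the diff (the input here is a dummy stand-in for the hardened layer output):

```python
import jax.numpy as jnp

num_classes = 10
bits = jnp.ones(1800)                     # dummy stand-in for the hardened outputs
scores = bits.reshape((num_classes, -1))  # 10 groups of 180 bits each
scores = scores.sum(-1)                   # per-class popcount -> class scores
print(scores.shape)                       # (10,)
```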
@@ -129,13 +141,13 @@ def get_datasets():
     train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
     test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
     # XXXX
-    train_ds["image"] = (jnp.float32(train_ds["image"]) / 255.0)
-    test_ds["image"] = (jnp.float32(test_ds["image"]) / 255.0)
+    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
+    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
     # TODO: we don't need to do this even when we don't use the real encoder
     # Use grayscale information
     # Convert the floating point values in [0,1] to binary values in {0,1}
-    #train_ds["image"] = jnp.round(train_ds["image"])
-    #test_ds["image"] = jnp.round(test_ds["image"])
+    # train_ds["image"] = jnp.round(train_ds["image"])
+    # test_ds["image"] = jnp.round(test_ds["image"])
     return train_ds, test_ds
 
 
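The normalization change above only drops redundant parentheses; behaviour is unchanged. For reference, a self-contained sketch of the same pipeline using the public `tensorflow_datasets` API, including the optional {0,1} binarization that the commented-out lines refer to:

```python
import jax.numpy as jnp
import tensorflow_datasets as tfds

ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0  # scale pixels to [0, 1]
# Optional: binarize to {0, 1}, as in the commented-out lines above
# train_ds["image"] = jnp.round(train_ds["image"])
```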
@@ -165,23 +177,16 @@ def create_train_state(net, rng, config):
     # for NLN
     mock_input = jnp.ones([1, 28 * 28])
     soft_weights = net.init(rng, mock_input)["params"]
-    #tx = optax.sgd(config.learning_rate, config.momentum)
-    #tx = optax.noisy_sgd(config.learning_rate, config.momentum)
+    # tx = optax.sgd(config.learning_rate, config.momentum)
+    # tx = optax.noisy_sgd(config.learning_rate, config.momentum)
     tx = optax.yogi(config.learning_rate)
     return train_state.TrainState.create(apply_fn=net.apply, params=soft_weights, tx=tx)
 
 
 def train_and_evaluate(
     net, datasets, config: ml_collections.ConfigDict, workdir: str
 ) -> train_state.TrainState:
-    """Execute model training and evaluation loop.
-    Args:
-        config: Hyperparameter configuration for training and evaluation.
-        workdir: Directory where the tensorboard summaries are written to.
-    Returns:
-        The train state (which includes the `.params`).
-    """
-    train_ds, test_ds = datasets
+    train_dataset, test_dataset = datasets
     rng = jax.random.PRNGKey(0)
 
     summary_writer = tensorboard.SummaryWriter(workdir)
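`create_train_state` keeps `optax.yogi` and leaves the SGD variants commented out. A minimal sketch of the same wiring, with a toy parameter pytree standing in for `net.init(...)["params"]` (the apply function here is illustrative, not the repository's network):

```python
import jax.numpy as jnp
import optax
from flax.training import train_state

params = {"w": jnp.zeros((28 * 28, 10))}  # toy stand-in for soft_weights
tx = optax.yogi(learning_rate=0.01)       # same optimizer as the diff
state = train_state.TrainState.create(
    apply_fn=lambda p, x: x @ p["w"],     # illustrative apply function
    params=params,
    tx=tx,
)
```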
@@ -193,10 +198,10 @@ def train_and_evaluate(
     for epoch in range(1, config.num_epochs + 1):
         rng, input_rng = jax.random.split(rng)
         state, train_loss, train_accuracy = train_epoch(
-            state, train_ds, config.batch_size, input_rng
+            state, train_dataset, config.batch_size, input_rng
         )
         _, test_loss, test_accuracy = apply_model_with_grad(
-            state, test_ds["image"], test_ds["label"]
+            state, test_dataset["image"], test_dataset["label"]
         )
 
         print(
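`train_epoch` itself is not part of this diff. The usual Flax-example pattern this loop assumes is to split a fresh key per epoch and shuffle indices into fixed-size minibatches; a sketch under that assumption, with illustrative sizes:

```python
import jax

rng = jax.random.PRNGKey(0)
num_examples, batch_size = 512, 128  # illustrative sizes, not the repo's config
for epoch in range(1, 3):
    rng, input_rng = jax.random.split(rng)  # fresh shuffle key per epoch
    perm = jax.random.permutation(input_rng, num_examples)
    perm = perm[: (num_examples // batch_size) * batch_size]  # drop remainder
    batches = perm.reshape((-1, batch_size))  # each row indexes one minibatch
```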
@@ -219,13 +224,13 @@ def get_config():
     # config for CNN
     config.learning_rate = 0.01
     # config for NLN
-    #config.learning_rate = 0.1
+    # config.learning_rate = 0.1
     config.learning_rate = 0.01
 
     # Always commit with num_epochs = 1 for short test time
     config.momentum = 0.9
     config.batch_size = 128
-    #config.num_epochs = 2
+    # config.num_epochs = 2
     config.num_epochs = 1000
     return config
 
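Note that `get_config` assigns `config.learning_rate` twice; an `ml_collections.ConfigDict` behaves like an attribute-accessed dict, so the last assignment wins. A minimal sketch of the pattern:

```python
import ml_collections

config = ml_collections.ConfigDict()
config.learning_rate = 0.01  # CNN setting
config.learning_rate = 0.01  # NLN setting; the last assignment takes effect
config.momentum = 0.9
config.batch_size = 128
config.num_epochs = 1        # keep 1 for short test time, per the comment above
print(config.learning_rate)  # 0.01
```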
@@ -290,6 +295,7 @@ def check_symbolic(nets, datasets, trained_state):
     symbolic_output = symbolic.apply({"params": symbolic_weights}, symbolic_input)
     print("symbolic_output", symbolic_output[0][:10000])
 
+
 @pytest.mark.skip(reason="temporarily off")
 def test_mnist():
     # Make sure tf does not allocate gpu memory.
@@ -311,13 +317,13 @@ def test_mnist():
 
     print(soft.tabulate(jax.random.PRNGKey(0), train_ds["image"][0:1]))
     # TODO: fix the size of this
-    #print(hard.tabulate(jax.random.PRNGKey(0), harden.harden(train_ds["image"][0:1])))
+    # print(hard.tabulate(jax.random.PRNGKey(0), harden.harden(train_ds["image"][0:1])))
 
     # Train and evaluate the model.
     trained_state = train_and_evaluate(
         soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics"
     )
 
     # Check symbolic net
-    #_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
-    #check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
+    # _, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
+    # check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
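The commented-out check refers to `neural_logic_net.net`, which, per its use earlier in this file, returns soft, hard, and symbolic variants of the same network. As an illustration only, one plausible hardening rule is thresholding soft values at 0.5; the library's actual `harden.harden` may differ:

```python
import jax.numpy as jnp

soft_bits = jnp.array([0.1, 0.49, 0.51, 0.9])
hard_bits = soft_bits > 0.5  # illustrative threshold, not the library's code
print(hard_bits)             # [False, False, True, True]
```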