|
18 | 18 | hard_not,
|
19 | 19 | hard_or,
|
20 | 20 | hard_xor,
|
| 21 | + hard_masks, |
21 | 22 | harden,
|
22 | 23 | harden_layer,
|
23 | 24 | neural_logic_net,
|
@@ -89,7 +90,7 @@ def check_symbolic(nets, datasets, trained_state, dropout_rng):
|
89 | 90 | print("symbolic_output", symbolic_output[0][:10000])
|
90 | 91 |
|
91 | 92 |
|
92 |
| -def nln(type, x, training: bool): |
| 93 | +def nln_1(type, x, training: bool): |
93 | 94 | num_classes = 10
|
94 | 95 |
|
95 | 96 | x = hard_or.or_layer(type)(
|
@@ -125,6 +126,23 @@ def nln_experimental(type, x, training: bool):
|
125 | 126 | return x
|
126 | 127 |
|
127 | 128 |
|
| 129 | +def nln(type, x, training: bool): |
| 130 | + input_size = 784 |
| 131 | + mask_layer_size = 10 |
| 132 | + dtype = jax.numpy.float32 |
| 133 | + x = hard_masks.mask_to_true_layer(type)(mask_layer_size, dtype=dtype)(x) |
| 134 | + x = x.reshape((int(mask_layer_size * 98), int(input_size / 98))) |
| 135 | + x = hard_majority.majority_layer(type)()(x) |
| 136 | + x = hard_not.not_layer(type)(20, dtype=dtype)(x) |
| 137 | + x = x.ravel() |
| 138 | + ############################## |
| 139 | + x = harden_layer.harden_layer(type)(x) |
| 140 | + num_classes = 10 |
| 141 | + x = x.reshape((num_classes, int(x.shape[0] / num_classes))) |
| 142 | + x = x.sum(-1) |
| 143 | + return x |
| 144 | + |
| 145 | + |
128 | 146 | def batch_nln(type, x, training: bool):
|
129 | 147 | return jax.vmap(lambda x: nln(type, x, training))(x)
|
130 | 148 |
|
@@ -299,13 +317,15 @@ def apply_hard_model_to_images(state, images, labels):
|
299 | 317 | def get_config():
|
300 | 318 | config = ml_collections.ConfigDict()
|
301 | 319 | # config for CNN: config.learning_rate = 0.01
|
302 |
| - config.learning_rate = 0.1 |
| 320 | + config.learning_rate = 0.01 |
303 | 321 | config.momentum = 0.9
|
304 | 322 | config.batch_size = 128
|
305 | 323 | config.num_epochs = 2
|
306 | 324 | return config
|
307 | 325 |
|
308 | 326 |
|
| 327 | +# TODO: check my use of rng |
| 328 | + |
309 | 329 | # @pytest.mark.skip(reason="temporarily off")
|
310 | 330 | def test_mnist():
|
311 | 331 | # Make sure tf does not allocate gpu memory.
|
|
0 commit comments