Skip to content

Commit 48d58e1

Browse files
committed
experiment with mask and majority
1 parent 2b30340 commit 48d58e1

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

tests/test_hard_not.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import jax
2-
import jax.numpy as jnp
32
import numpy
43
import optax
54
from flax.training import train_state

tests/test_mnist.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
hard_not,
1919
hard_or,
2020
hard_xor,
21+
hard_masks,
2122
harden,
2223
harden_layer,
2324
neural_logic_net,
@@ -89,7 +90,7 @@ def check_symbolic(nets, datasets, trained_state, dropout_rng):
8990
print("symbolic_output", symbolic_output[0][:10000])
9091

9192

92-
def nln(type, x, training: bool):
93+
def nln_1(type, x, training: bool):
9394
num_classes = 10
9495

9596
x = hard_or.or_layer(type)(
@@ -125,6 +126,23 @@ def nln_experimental(type, x, training: bool):
125126
return x
126127

127128

129+
def nln(type, x, training: bool):
130+
input_size = 784
131+
mask_layer_size = 10
132+
dtype = jax.numpy.float32
133+
x = hard_masks.mask_to_true_layer(type)(mask_layer_size, dtype=dtype)(x)
134+
x = x.reshape((int(mask_layer_size * 98), int(input_size / 98)))
135+
x = hard_majority.majority_layer(type)()(x)
136+
x = hard_not.not_layer(type)(20, dtype=dtype)(x)
137+
x = x.ravel()
138+
##############################
139+
x = harden_layer.harden_layer(type)(x)
140+
num_classes = 10
141+
x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
142+
x = x.sum(-1)
143+
return x
144+
145+
128146
def batch_nln(type, x, training: bool):
129147
return jax.vmap(lambda x: nln(type, x, training))(x)
130148

@@ -299,13 +317,15 @@ def apply_hard_model_to_images(state, images, labels):
299317
def get_config():
300318
config = ml_collections.ConfigDict()
301319
# config for CNN: config.learning_rate = 0.01
302-
config.learning_rate = 0.1
320+
config.learning_rate = 0.01
303321
config.momentum = 0.9
304322
config.batch_size = 128
305323
config.num_epochs = 2
306324
return config
307325

308326

327+
# TODO: check my use of rng
328+
309329
# @pytest.mark.skip(reason="temporarily off")
310330
def test_mnist():
311331
# Make sure tf does not allocate gpu memory.

0 commit comments

Comments (0)