
Commit 3bcb6d9

experimenting with archs
1 parent d33001b

3 files changed: +46 −10 lines

neurallogic/harden_layer.py (1 addition, 0 deletions)

```diff
@@ -18,6 +18,7 @@ def straight_through_harden_element(x):
 def hard_harden_layer(x):
     return x
 
+#TODO: can we harden arbitrary tensors?
 #TODO: is this correct?
 def symbolic_harden_layer(x):
     return x
```
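The new TODO asks whether arbitrary tensors can be hardened. A minimal sketch of one plausible reading, assuming "hardening" means elementwise thresholding of soft bits in [0, 1] at 0.5 (`harden_tensor` and the threshold are illustrative, not from this repo):

```python
import jax.numpy as jnp

def harden_tensor(x):
    # Elementwise threshold: soft bits in [0, 1] become hard booleans.
    # The comparison broadcasts, so tensors of any shape work.
    return x > 0.5

soft = jnp.array([[0.1, 0.7], [0.9, 0.5]])
print(harden_tensor(soft))  # -> [[False True], [True False]]
```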

tests/test_hard_majority.py (1 addition, 0 deletions)

```diff
@@ -163,3 +163,4 @@ def test_majority_layer():
     symbolic_output = symbolic_generation.symbolic_expression(jaxpr, harden.harden(input))
     assert jax.numpy.array_equal(symbolic_output, harden.harden(expected))
 
+# TODO: test training the hard majority layer
```
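A sketch of what such a training test could check, assuming a differentiable soft majority implemented as the mean of the soft bits (the `soft_majority` stand-in below is illustrative; the repo's `hard_majority.majority_layer` API may differ). The point is simply that gradients flow through the vote, so an upstream parameter can be trained:

```python
import jax
import jax.numpy as jnp

def soft_majority(x):
    # Differentiable stand-in for a majority vote: the mean of soft
    # bits along the last axis exceeds 0.5 iff most inputs are "true".
    return jnp.mean(x, axis=-1)

def loss(w, x, target):
    # `w` is a trainable soft-bit mask applied before the vote.
    return jnp.sum(jnp.square(soft_majority(x * w) - target))

x = jnp.array([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
target = jnp.array([0.0, 1.0])
w = jnp.full((3,), 0.5)
for _ in range(200):
    w = w - 0.1 * jax.grad(loss)(w, x, target)  # plain gradient descent
assert loss(w, x, target) < 0.1  # gradients flowed; the loss went down
```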

tests/test_mnist.py (44 additions, 10 deletions)
```diff
@@ -3,16 +3,21 @@
 import ml_collections
 import numpy as np
 import optax
+import pytest
 import tensorflow as tf
 import tensorflow_datasets as tfds
 from flax import linen as nn
 from flax.metrics import tensorboard
 from flax.training import train_state
+from jax.config import config
 from matplotlib import pyplot as plt
 from tqdm import tqdm
 
-from neurallogic import (hard_not, hard_or, harden, harden_layer,
-                         neural_logic_net)
+from neurallogic import (hard_and, hard_majority, hard_not, hard_or, harden,
+                         harden_layer, neural_logic_net, real_encoder)
+
+# Uncomment to debug NaNs
+#config.update("jax_debug_nans", True)
 
 """
 MNIST test.
```
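As an aside, `from jax.config import config` is deprecated in recent JAX releases; the same NaN check can be enabled through the top-level `jax.config` object (a sketch assuming a current JAX version):

```python
import jax

# Raise an error as soon as any operation produces a NaN,
# instead of letting it propagate silently through training.
jax.config.update("jax_debug_nans", True)
```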
```diff
@@ -21,16 +26,39 @@
 The data is loaded using tensorflow_datasets.
 """
 
-
+# TODO: experiment in ipython notebook with different values for these
+"""
 def nln(type, x, width):
-    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0))(x)
+    #x = x.reshape((-1, 1))
+    #re = real_encoder.real_encoder_layer(type)(3)
+    #x = jax.vmap(re, 0)(x)
+    #x = x.ravel()
+    #x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
+    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
     x = hard_not.not_layer(type)(10)(x)
     x = x.ravel() # flatten the outputs of the not layer
     # harden the outputs of the not layer
     x = harden_layer.harden_layer(type)(x)
     x = x.reshape((10, width)) # reshape to 10 ports, 100 bits each
     x = x.sum(-1) # sum the 100 bits in each port
     return x
+"""
+
+def nln(type, x, width):
+    majority_size = 3
+    n = int(width / majority_size)
+    not_size = 200
+    num_classes = 10
+
+    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x) # width number of or neurons
+    x = x.reshape((n, majority_size)) # reshape to (n, majority_size)
+    x = hard_majority.majority_layer(type)(x) # reduce to (n,)
+    x = hard_not.not_layer(type)(not_size, dtype=jax.numpy.float16)(x) # (not_size, n)
+    x = x.ravel() # flatten the outputs of the not layer (not_size * n,)
+    x = harden_layer.harden_layer(type)(x) # (not_size * n,)
+    x = x.reshape((num_classes, int(not_size * n / num_classes))) # reshape to num_classes ports (num_classes, (not_size * n) / num_classes)
+    x = x.sum(-1) # sum the bits in each port
+    return x
 
 
 def batch_nln(type, x, width):
```
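The comments in the new `nln` trace the tensor shapes; a standalone check of that arithmetic for the `width = 1599` set later in this diff (pure jnp stand-ins for the repo's layers, so only the shapes are meaningful here):

```python
import jax.numpy as jnp

width, majority_size, not_size, num_classes = 1599, 3, 200, 10
n = width // majority_size          # 533 majority gates; width must divide evenly

x = jnp.zeros(width)                # stand-in for the or_layer output: (1599,)
x = x.reshape((n, majority_size))   # (533, 3)
x = jnp.mean(x, axis=-1)            # stand-in majority vote: (533,)
x = jnp.zeros((not_size, n))        # stand-in for the not_layer output: (200, 533)
x = x.ravel()                       # (106600,)
x = x.reshape((num_classes, not_size * n // num_classes))  # (10, 10660)
x = x.sum(-1)                       # one score per class: (10,)
print(x.shape)                      # (10,)
```

Both reshapes constrain the hyperparameters: 1599 = 3 × 533, and 200 × 533 = 106600 is divisible by 10, which is presumably why `width` is the odd-looking 1599 rather than a round number.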
```diff
@@ -107,6 +135,8 @@ def get_datasets():
     test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
     train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
     test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
+    # TODO: we don't need to do this even when we don't use the real encoder
+    # Use grayscale information
     # Convert the floating point values in [0,1] to binary values in {0,1}
     train_ds["image"] = jnp.round(train_ds["image"])
     test_ds["image"] = jnp.round(test_ds["image"])
```
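For illustration, what the rounding step discards (plain jnp, nothing repo-specific): every grayscale level collapses to a single bit, which is the information a real-encoder input path could otherwise consume.

```python
import jax.numpy as jnp

pixels = jnp.array([0.12, 0.49, 0.51, 0.98])  # normalized grayscale values
print(jnp.round(pixels))                      # [0. 0. 1. 1.] -- 256 levels become 2
```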
```diff
@@ -191,12 +221,14 @@ def get_config():
     # config for CNN
     config.learning_rate = 0.01
     # config for NLN
-    config.learning_rate = 0.1
+    #config.learning_rate = 0.1
+    config.learning_rate = 0.01
 
     # Always commit with num_epochs = 1 for short test time
     config.momentum = 0.9
     config.batch_size = 128
-    config.num_epochs = 2
+    #config.num_epochs = 2
+    config.num_epochs = 1000
     return config
 
 
```
```diff
@@ -260,7 +292,7 @@ def check_symbolic(nets, datasets, trained_state):
     symbolic_output = symbolic.apply({"params": symbolic_weights}, symbolic_input)
     print("symbolic_output", symbolic_output[0][:10000])
 
-
+@pytest.mark.skip(reason="temporarily off")
 def test_mnist():
     # Make sure tf does not allocate gpu memory.
     tf.config.experimental.set_visible_devices([], "GPU")
```
```diff
@@ -270,7 +302,7 @@ def test_mnist():
 
     # Define the model.
     # soft = CNN()
-    width = 10
+    width = 1599
     soft, _, _ = neural_logic_net.net(lambda type, x: batch_nln(type, x, width))
 
     # Get the MNIST dataset.
```
```diff
@@ -279,11 +311,13 @@ def test_mnist():
     train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1))
     test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1))
 
+    print(soft.tabulate(jax.random.PRNGKey(0), train_ds["image"][0:1]))
+
     # Train and evaluate the model.
     trained_state = train_and_evaluate(
         soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics"
     )
 
     # Check symbolic net
-    _, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x, width))
-    check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
+    #_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x, width))
+    #check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
```
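The new `print(soft.tabulate(...))` call uses Flax's built-in module summary. For reference, a minimal standalone example of the same API (assuming a recent Flax, with a toy `nn.Dense` model rather than the repo's net):

```python
import jax
import flax.linen as nn

model = nn.Dense(features=4)
x = jax.numpy.ones((1, 8))
# tabulate returns a printable table of layers, shapes and parameter counts
print(model.tabulate(jax.random.PRNGKey(0), x))
```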
