Commit c81c44e

1 parent 27a4d9f commit c81c44e

3 files changed: +217 -38 lines changed

neurallogic/hard_xor.py

Lines changed: 20 additions & 10 deletions
@@ -3,7 +3,7 @@
 import jax
 from flax import linen as nn

-from neurallogic import neural_logic_net, symbolic_generation
+from neurallogic import neural_logic_net, symbolic_generation, hard_and


 def soft_xor_include(w: float, x: float) -> float:
@@ -28,6 +28,7 @@ def soft_xor_neuron(w, x):

     def xor(x, y):
         return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
+
     x = jax.lax.reduce(x, jax.numpy.array(0, dtype=x.dtype), xor, (0,))
     return x

@@ -45,15 +46,19 @@ def hard_xor_neuron(w, x):

 class SoftXorLayer(nn.Module):
     layer_size: int
-    weights_init: Callable = nn.initializers.uniform(1.0)
+    weights_init: Callable = (
+        nn.initializers.uniform(1.0)
+        #hard_and.initialize_near_to_zero()
+    )
     dtype: jax.numpy.dtype = jax.numpy.float32

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
+        x = jax.numpy.asarray(x, self.dtype)
         return soft_xor_layer(weights, x)


@@ -64,7 +69,8 @@ class HardXorLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(True), weights_shape)
+            "bit_weights", nn.initializers.constant(True), weights_shape
+        )
         return hard_xor_layer(weights, x)


@@ -74,14 +80,18 @@ def __init__(self, layer_size):
         self.hard_xor_layer = HardXorLayer(self.layer_size)

     def __call__(self, x):
-        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
-            self.hard_xor_layer, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_xor_layer, x)
         return symbolic_generation.symbolic_expression(jaxpr, x)


 xor_layer = neural_logic_net.select(
     lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0), dtype=jax.numpy.float32: SoftXorLayer(layer_size, weights_init, dtype),
+        1.0
+    ), dtype=jax.numpy.float32: SoftXorLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: HardXorLayer(layer_size),
     lambda layer_size, weights_init=nn.initializers.constant(
-        True), dtype=jax.numpy.float32: HardXorLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicXorLayer(layer_size))
+        True
+    ), dtype=jax.numpy.float32: SymbolicXorLayer(layer_size),
+)
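Note on the soft XOR above: the nested xor(x, y) in soft_xor_neuron computes min(max(x, y), 1 - min(x, y)), which equals boolean XOR at the {0, 1} corners and interpolates in between; jax.lax.reduce then folds it across the inputs. A minimal standalone sketch (not part of this commit; the soft_xor name is only for illustration):

import jax.numpy as jnp

def soft_xor(x, y):
    # same formula as the nested xor() in soft_xor_neuron
    return jnp.minimum(jnp.maximum(x, y), 1.0 - jnp.minimum(x, y))

# exact XOR at the boolean corners
for a in (0.0, 1.0):
    for b in (0.0, 1.0):
        assert float(soft_xor(a, b)) == float(bool(a) != bool(b))

print(soft_xor(0.9, 0.1))  # 0.9: inputs disagree, output near 1
print(soft_xor(0.9, 0.8))  # 0.2: inputs agree, output near 0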

tests/test_mnist.py

Lines changed: 34 additions & 28 deletions
@@ -13,11 +13,20 @@
 from matplotlib import pyplot as plt
 from tqdm import tqdm

-from neurallogic import (hard_and, hard_majority, hard_not, hard_or, hard_xor, harden,
-                         harden_layer, neural_logic_net, real_encoder)
+from neurallogic import (
+    hard_and,
+    hard_majority,
+    hard_not,
+    hard_or,
+    hard_xor,
+    harden,
+    harden_layer,
+    neural_logic_net,
+    real_encoder,
+)

 # Uncomment to debug NaNs
-#config.update("jax_debug_nans", True)
+# config.update("jax_debug_nans", True)

 """
 MNIST test.
@@ -44,15 +53,18 @@ def nln(type, x, width):
     return x
 """

+
 def nln(type, x):
     num_classes = 10

-    x = hard_or.or_layer(type)(1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16)(x)
+    x = hard_or.or_layer(type)(
+        1800, nn.initializers.uniform(1.0), dtype=jax.numpy.float16
+    )(x)
     x = hard_not.not_layer(type)(1, dtype=jax.numpy.float16)(x)
     x = x.ravel()
-    x = harden_layer.harden_layer(type)(x)
-    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
-    x = x.sum(-1)
+    x = harden_layer.harden_layer(type)(x)
+    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
+    x = x.sum(-1)
     return x


@@ -129,13 +141,13 @@ def get_datasets():
     train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
     test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
     # XXXX
-    train_ds["image"] = (jnp.float32(train_ds["image"]) / 255.0)
-    test_ds["image"] = (jnp.float32(test_ds["image"]) / 255.0)
+    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
+    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
     # TODO: we don't need to do this even when we don't use the real encoder
     # Use grayscale information
     # Convert the floating point values in [0,1] to binary values in {0,1}
-    #train_ds["image"] = jnp.round(train_ds["image"])
-    #test_ds["image"] = jnp.round(test_ds["image"])
+    # train_ds["image"] = jnp.round(train_ds["image"])
+    # test_ds["image"] = jnp.round(test_ds["image"])
     return train_ds, test_ds


@@ -165,23 +177,16 @@ def create_train_state(net, rng, config):
     # for NLN
     mock_input = jnp.ones([1, 28 * 28])
     soft_weights = net.init(rng, mock_input)["params"]
-    #tx = optax.sgd(config.learning_rate, config.momentum)
-    #tx = optax.noisy_sgd(config.learning_rate, config.momentum)
+    # tx = optax.sgd(config.learning_rate, config.momentum)
+    # tx = optax.noisy_sgd(config.learning_rate, config.momentum)
     tx = optax.yogi(config.learning_rate)
     return train_state.TrainState.create(apply_fn=net.apply, params=soft_weights, tx=tx)


 def train_and_evaluate(
     net, datasets, config: ml_collections.ConfigDict, workdir: str
 ) -> train_state.TrainState:
-    """Execute model training and evaluation loop.
-    Args:
-      config: Hyperparameter configuration for training and evaluation.
-      workdir: Directory where the tensorboard summaries are written to.
-    Returns:
-      The train state (which includes the `.params`).
-    """
-    train_ds, test_ds = datasets
+    train_dataset, test_dataset = datasets
     rng = jax.random.PRNGKey(0)

     summary_writer = tensorboard.SummaryWriter(workdir)
@@ -193,10 +198,10 @@ def train_and_evaluate(
     for epoch in range(1, config.num_epochs + 1):
         rng, input_rng = jax.random.split(rng)
         state, train_loss, train_accuracy = train_epoch(
-            state, train_ds, config.batch_size, input_rng
+            state, train_dataset, config.batch_size, input_rng
         )
         _, test_loss, test_accuracy = apply_model_with_grad(
-            state, test_ds["image"], test_ds["label"]
+            state, test_dataset["image"], test_dataset["label"]
         )

         print(
@@ -219,13 +224,13 @@ def get_config():
     # config for CNN
     config.learning_rate = 0.01
     # config for NLN
-    #config.learning_rate = 0.1
+    # config.learning_rate = 0.1
     config.learning_rate = 0.01

     # Always commit with num_epochs = 1 for short test time
     config.momentum = 0.9
     config.batch_size = 128
-    #config.num_epochs = 2
+    # config.num_epochs = 2
     config.num_epochs = 1000
     return config

@@ -290,6 +295,7 @@ def check_symbolic(nets, datasets, trained_state):
     symbolic_output = symbolic.apply({"params": symbolic_weights}, symbolic_input)
     print("symbolic_output", symbolic_output[0][:10000])

+
 @pytest.mark.skip(reason="temporarily off")
 def test_mnist():
     # Make sure tf does not allocate gpu memory.
@@ -311,13 +317,13 @@ def test_mnist():

     print(soft.tabulate(jax.random.PRNGKey(0), train_ds["image"][0:1]))
     # TODO: fix the size of this
-    #print(hard.tabulate(jax.random.PRNGKey(0), harden.harden(train_ds["image"][0:1])))
+    # print(hard.tabulate(jax.random.PRNGKey(0), harden.harden(train_ds["image"][0:1])))

     # Train and evaluate the model.
     trained_state = train_and_evaluate(
         soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics"
     )

     # Check symbolic net
-    #_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
-    #check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
+    # _, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x))
+    # check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
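Note on the readout in nln() above: after harden_layer the activations are (near-)binary, and the reshape/sum pair turns the flat bit vector into one vote count per class, which is then used as the logits. A toy illustration (not part of this commit; the bit values are invented):

import jax.numpy as jnp

num_classes = 2
bits = jnp.array([1.0, 1.0, 0.0, 0.0, 1.0, 0.0])  # toy hardened output, 3 bits per class
votes = bits.reshape((num_classes, bits.shape[0] // num_classes)).sum(-1)
print(votes)  # [2. 1.] -> class 0 gets the most votes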

tests/test_noisy_xor.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+from pathlib import Path
+import ml_collections
+import numpy
+import optax
+from flax.training import train_state
+from flax import linen as nn
+import jax
+
+from neurallogic import (
+    neural_logic_net,
+    hard_not,
+    hard_or,
+    hard_and,
+    hard_xor,
+    hard_majority,
+    harden_layer,
+)
+
+num_features = 12
+num_classes = 2
+
+
+def get_data():
+    # Create a path to the data directory
+    data_dir = Path(__file__).parent.parent / "tests" / "data"
+    # Load the training data
+    training_data = numpy.loadtxt(data_dir / "NoisyXORTrainingData.txt").astype(
+        dtype=numpy.int32
+    )
+    # Load the test data
+    test_data = numpy.loadtxt(data_dir / "NoisyXORTestData.txt").astype(
+        dtype=numpy.int32
+    )
+    return training_data, test_data
+
+
+# 89% test accuracy
+def nln_89(type, x):
+    x = hard_and.and_layer(type)(20)(x)
+    x = hard_not.not_layer(type)(5)(x)
+    x = x.ravel()
+    ########################################################
+    x = harden_layer.harden_layer(type)(x)
+    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
+    x = x.sum(-1)
+    return x
+
+
+# 100% test accuracy
+def nln(type, x):
+    x = hard_and.and_layer(type)(20)(x)
+    x = hard_not.not_layer(type)(4)(x)
+    x = x.ravel()
+    ########################################################
+    x = harden_layer.harden_layer(type)(x)
+    x = x.reshape((num_classes, int(x.shape[0] / num_classes)))
+    x = x.sum(-1)
+    return x
+
+
+def batch_nln(type, x):
+    return jax.vmap(lambda x: nln(type, x))(x)
+
+
+def create_train_state(net, rng, config):
+    mock_input = jax.numpy.ones([1, num_features])
+    soft_weights = net.init(rng, mock_input)["params"]
+    # tx = optax.sgd(config.learning_rate, config.momentum)
+    tx = optax.yogi(config.learning_rate)
+    return train_state.TrainState.create(apply_fn=net.apply, params=soft_weights, tx=tx)
+
+
+@jax.jit
+def update_model(state, grads):
+    return state.apply_gradients(grads=grads)
+
+
+@jax.jit
+def apply_model_with_grad(state, features, labels):
+    def loss_fn(params):
+        logits = state.apply_fn({"params": params}, features)
+        one_hot = jax.nn.one_hot(labels, num_classes)
+        loss = jax.numpy.mean(
+            optax.softmax_cross_entropy(logits=logits, labels=one_hot)
+        )
+        return loss, logits
+
+    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+    (loss, logits), grads = grad_fn(state.params)
+    accuracy = jax.numpy.mean(jax.numpy.argmax(logits, -1) == labels)
+    return grads, loss, accuracy
+
+
+def train_epoch(state, features, labels, batch_size, rng):
+    train_ds_size = len(features)
+    steps_per_epoch = train_ds_size // batch_size
+
+    perms = jax.random.permutation(rng, len(features))
+    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
+    perms = perms.reshape((steps_per_epoch, batch_size))
+
+    epoch_loss = []
+    epoch_accuracy = []
+
+    for perm in perms:
+        batch_features = features[perm, ...]
+        batch_labels = labels[perm, ...]
+        grads, loss, accuracy = apply_model_with_grad(
+            state, batch_features, batch_labels
+        )
+        state = update_model(state, grads)
+        epoch_loss.append(loss)
+        epoch_accuracy.append(accuracy)
+    train_loss = numpy.mean(epoch_loss)
+    train_accuracy = numpy.mean(epoch_accuracy)
+    return state, train_loss, train_accuracy
+
+
+def train_and_evaluate(net, datasets, config: ml_collections.ConfigDict):
+    training_data, test_data = datasets
+    x_training = training_data[:, 0:num_features]  # Input features
+    y_training = training_data[:, num_features]  # Target value
+    x_test = test_data[:, 0:num_features]  # Input features
+    y_test = test_data[:, num_features]  # Target value
+
+    rng = jax.random.PRNGKey(0)
+    print(net.tabulate(rng, x_training[0:1]))
+
+    rng, init_rng = jax.random.split(rng)
+    state = create_train_state(net, init_rng, config)
+
+    best_test_accuracy = 0.0
+    for epoch in range(1, config.num_epochs + 1):
+        rng, input_rng = jax.random.split(rng)
+        state, train_loss, train_accuracy = train_epoch(
+            state, x_training, y_training, config.batch_size, input_rng
+        )
+        _, test_loss, test_accuracy = apply_model_with_grad(state, x_test, y_test)
+        if test_accuracy > best_test_accuracy:
+            best_test_accuracy = test_accuracy
+
+        print(
+            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
+            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
+        )
+    print(f"best_test_accuracy: {best_test_accuracy * 100:.2f}")
+
+    return state
+
+
+def get_config():
+    config = ml_collections.ConfigDict()
+    config.learning_rate = 0.01
+    config.momentum = 0.9
+    config.batch_size = 256
+    config.num_epochs = 1000
+    return config
+
+
+def test_noisy_xor():
+    soft, hard, _ = neural_logic_net.net(lambda type, x: batch_nln(type, x))
+    training_data, test_data = get_data()
+    trained_state = train_and_evaluate(soft, (training_data, test_data), get_config())
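Note on the data handling in the new test: train_and_evaluate() slices each row into input features and a final label column. A minimal sketch of that slicing (not part of this commit; it assumes each NoisyXOR row stores 12 features followed by one label, and the example rows are invented):

import numpy

num_features = 12
rows = numpy.array(
    [
        [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1],  # 12 features, then the label
        [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0],
    ],
    dtype=numpy.int32,
)
x = rows[:, 0:num_features]  # shape (2, 12): inputs fed to the network
y = rows[:, num_features]    # shape (2,): target classes
print(x.shape, y.shape)      # (2, 12) (2,)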
