
Commit 87671ff ("rename")

Parent: ce5195d

File tree: 4 files changed (+210, -95 lines)

neurallogic/hard_bit.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

neurallogic/real_encoder.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import jax
+
+
+def soft_real_encoder(t: float, x: float) -> float:
+    # x should be in [0, 1]
+    t = jax.numpy.clip(t, 0.0, 1.0)
+    return jax.numpy.where(
+        x == t,
+        0.5,
+        jax.numpy.where(
+            x < t,
+            (1.0 / (2.0 * t)) * x,
+            (1.0 / (2.0 * (1.0 - t))) * (x + 1.0 - 2.0 * t),
+        ),
+    )
+
+
+def hard_real_encoder(t: float, x: float) -> bool:
+    # t and x must be floats
+    return jax.numpy.where(soft_real_encoder(t, x) > 0.5, True, False)
+
+
+soft_real_encoder_neuron = jax.vmap(soft_real_encoder, in_axes=(0, None))
+
+hard_real_encoder_neuron = jax.vmap(hard_real_encoder, in_axes=(0, None))
+
+soft_real_encoder_layer = jax.vmap(soft_real_encoder_neuron, (0, 0), 0)
+
+hard_real_encoder_layer = jax.vmap(hard_real_encoder_neuron, (0, 0), 0)
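For orientation: soft_real_encoder is piecewise linear in x, mapping [0, t] onto [0, 0.5] and [t, 1] onto [0.5, 1], so its output crosses 0.5 exactly at the threshold t; hard_real_encoder is therefore True exactly when x > t. The sketch below is not part of the commit and assumes the module is importable as neurallogic.real_encoder, the path added above:

    import jax.numpy as jnp

    from neurallogic.real_encoder import soft_real_encoder, soft_real_encoder_layer

    t = 0.25
    print(soft_real_encoder(t, 0.0))   # 0.0  (below the threshold)
    print(soft_real_encoder(t, 0.25))  # 0.5  (exactly at the threshold)
    print(soft_real_encoder(t, 1.0))   # 1.0  (above the threshold)

    # The layer form vmaps over both arguments: input xs[i] is encoded
    # against its own row of thresholds, giving one soft bit per threshold.
    thresholds = jnp.array([[0.2, 0.5], [0.4, 0.9]])  # shape (2, 2)
    xs = jnp.array([0.3, 0.6])                        # shape (2,)
    print(soft_real_encoder_layer(thresholds, xs).shape)  # (2, 2)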

tests/test_mnist.py

Lines changed: 69 additions & 66 deletions
@@ -9,8 +9,7 @@
 from flax.metrics import tensorboard
 from flax.training import train_state
 import ml_collections
-from neurallogic import (hard_not, hard_or, harden, harden_layer,
-                         neural_logic_net)
+from neurallogic import hard_not, hard_or, harden, harden_layer, neural_logic_net
 import optax

@@ -23,9 +22,8 @@
 def nln(type, x, width):
-    x = hard_or.or_layer(type)(width, nn.initializers.uniform(
-        1.0), dtype=jnp.float32)(x)  # >=1700 need for >98% accuracy
-    x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
+    x = hard_or.or_layer(type)(width, nn.initializers.uniform(1.0))(x)
+    x = hard_not.not_layer(type)(10)(x)
     x = x.ravel()  # flatten the outputs of the not layer
     # harden the outputs of the not layer
     x = harden_layer.harden_layer(type)(x)
@@ -59,11 +57,11 @@ def __call__(self, x):
 @jax.jit
 def apply_model_with_grad(state, images, labels):
     """Computes gradients, loss and accuracy for a single batch."""
+
     def loss_fn(params):
-        logits = state.apply_fn({'params': params}, images)
+        logits = state.apply_fn({"params": params}, images)
         one_hot = jax.nn.one_hot(labels, 10)
-        loss = jnp.mean(optax.softmax_cross_entropy(
-            logits=logits, labels=one_hot))
+        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
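The loss in loss_fn is standard softmax cross-entropy against one-hot labels. A standalone sketch, not from the commit, showing the same computation and its log-softmax equivalent:

    import jax
    import jax.numpy as jnp
    import optax

    logits = jnp.array([[2.0, 0.5, -1.0]])       # one example, three classes
    one_hot = jax.nn.one_hot(jnp.array([0]), 3)  # true class is 0
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    # By definition this equals -sum(one_hot * log_softmax(logits)):
    manual = jnp.mean(-jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))
    assert jnp.allclose(loss, manual)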
@@ -79,21 +77,20 @@ def update_model(state, grads):
 def train_epoch(state, train_ds, batch_size, rng):
     """Train for a single epoch."""
-    train_ds_size = len(train_ds['image'])
+    train_ds_size = len(train_ds["image"])
     steps_per_epoch = train_ds_size // batch_size

-    perms = jax.random.permutation(rng, len(train_ds['image']))
-    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
+    perms = jax.random.permutation(rng, len(train_ds["image"]))
+    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
     perms = perms.reshape((steps_per_epoch, batch_size))

     epoch_loss = []
     epoch_accuracy = []

     for perm in perms:
-        batch_images = train_ds['image'][perm, ...]
-        batch_labels = train_ds['label'][perm, ...]
-        grads, loss, accuracy = apply_model_with_grad(
-            state, batch_images, batch_labels)
+        batch_images = train_ds["image"][perm, ...]
+        batch_labels = train_ds["label"][perm, ...]
+        grads, loss, accuracy = apply_model_with_grad(state, batch_images, batch_labels)
         state = update_model(state, grads)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
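The permutation logic above shuffles the example indices once per epoch, drops the trailing partial batch, and reshapes the rest into (steps_per_epoch, batch_size). A minimal self-contained sketch of the same pattern on toy data, not from the commit:

    import jax
    import jax.numpy as jnp

    images = jnp.arange(10 * 4).reshape(10, 4)   # 10 examples, 4 features each
    batch_size = 3
    steps_per_epoch = len(images) // batch_size  # 3 full batches; 1 example dropped

    perms = jax.random.permutation(jax.random.PRNGKey(0), len(images))
    perms = perms[: steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = images[perm, ...]  # shape (3, 4)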
@@ -103,24 +100,23 @@ def train_epoch(state, train_ds, batch_size, rng):
 def get_datasets():
-    ds_builder = tfds.builder('mnist')
+    ds_builder = tfds.builder("mnist")
     ds_builder.download_and_prepare()
-    train_ds = tfds.as_numpy(
-        ds_builder.as_dataset(split='train', batch_size=-1))
-    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
-    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
-    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
+    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
+    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
+    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
+    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
     # Convert the floating point values in [0,1] to binary values in {0,1}
-    train_ds['image'] = jnp.round(train_ds['image'])
-    test_ds['image'] = jnp.round(test_ds['image'])
+    train_ds["image"] = jnp.round(train_ds["image"])
+    test_ds["image"] = jnp.round(test_ds["image"])
     return train_ds, test_ds


 def show_img(img, ax=None, title=None):
     """Shows a single image."""
     if ax is None:
         ax = plt.gca()
-    ax.imshow(img.reshape(28, 28), cmap='gray')
+    ax.imshow(img.reshape(28, 28), cmap="gray")
     ax.set_xticks([])
     ax.set_yticks([])
     if title:
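Note the binarization step in the get_datasets hunk above: pixel intensities are scaled into [0, 1] and then rounded, so anything at or above 128/255 becomes 1. A quick sketch, not from the commit:

    import jax.numpy as jnp

    pixels = jnp.float32([0, 64, 127, 128, 255]) / 255.0
    print(jnp.round(pixels))  # [0. 0. 0. 1. 1.]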
@@ -129,7 +125,7 @@ def show_img(img, ax=None, title=None):
 def show_img_grid(imgs, titles):
     """Shows a grid of images."""
-    n = int(np.ceil(len(imgs)**.5))
+    n = int(np.ceil(len(imgs) ** 0.5))
     _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
     for i, (img, title) in enumerate(zip(imgs, titles)):
         show_img(img, axs[i // n][i % n], title)
@@ -141,13 +137,14 @@ def create_train_state(net, rng, config):
     # mock_input = jnp.ones([1, 28, 28, 1])
     # for NLN
     mock_input = jnp.ones([1, 28 * 28])
-    soft_weights = net.init(rng, mock_input)['params']
+    soft_weights = net.init(rng, mock_input)["params"]
     tx = optax.sgd(config.learning_rate, config.momentum)
     return train_state.TrainState.create(apply_fn=net.apply, params=soft_weights, tx=tx)


-def train_and_evaluate(net, datasets, config: ml_collections.ConfigDict,
-                       workdir: str) -> train_state.TrainState:
+def train_and_evaluate(
+    net, datasets, config: ml_collections.ConfigDict, workdir: str
+) -> train_state.TrainState:
     """Execute model training and evaluation loop.
     Args:
       config: Hyperparameter configuration for training and evaluation.
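create_train_state follows the usual Flax pattern: initialize parameters from a mock input, then bundle apply_fn, params, and an optax optimizer into a TrainState. A minimal sketch of the same pattern, with a plain Dense model as a hypothetical stand-in for the logic net:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import optax
    from flax.training import train_state

    model = nn.Dense(10)  # hypothetical stand-in, not the NLN from this repo
    params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))["params"]
    state = train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.sgd(0.1, 0.9)
    )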
@@ -166,21 +163,22 @@
     for epoch in range(1, config.num_epochs + 1):
         rng, input_rng = jax.random.split(rng)
-        state, train_loss, train_accuracy = train_epoch(state, train_ds,
-                                                        config.batch_size,
-                                                        input_rng)
-        _, test_loss, test_accuracy = apply_model_with_grad(state, test_ds['image'],
-                                                            test_ds['label'])
+        state, train_loss, train_accuracy = train_epoch(
+            state, train_ds, config.batch_size, input_rng
+        )
+        _, test_loss, test_accuracy = apply_model_with_grad(
+            state, test_ds["image"], test_ds["label"]
+        )

         print(
-            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
-            % (epoch, train_loss, train_accuracy * 100, test_loss,
-               test_accuracy * 100))
+            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
+            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
+        )

-        summary_writer.scalar('train_loss', train_loss, epoch)
-        summary_writer.scalar('train_accuracy', train_accuracy, epoch)
-        summary_writer.scalar('test_loss', test_loss, epoch)
-        summary_writer.scalar('test_accuracy', test_accuracy, epoch)
+        summary_writer.scalar("train_loss", train_loss, epoch)
+        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
+        summary_writer.scalar("test_loss", test_loss, epoch)
+        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

     return state

@@ -197,13 +195,13 @@ def get_config():
     # Always commit with num_epochs = 1 for short test time
     config.momentum = 0.9
     config.batch_size = 128
-    config.num_epochs = 2
+    config.num_epochs = 1000
     return config


 def apply_hard_model(state, image, label):
     def logits_fn(params):
-        return state.apply_fn({'params': params}, image)
+        return state.apply_fn({"params": params}, image)

     logits = logits_fn(state.params)
     if isinstance(logits, list):
@@ -224,33 +222,41 @@ def check_symbolic(nets, datasets, trained_state):
     _, test_ds = datasets
     _, hard, symbolic = nets
     _, test_loss, test_accuracy = apply_model_with_grad(
-        trained_state, test_ds['image'], test_ds['label'])
-    print('soft_net: final test_loss: %.4f, final test_accuracy: %.2f' %
-          (test_loss, test_accuracy * 100))
+        trained_state, test_ds["image"], test_ds["label"]
+    )
+    print(
+        "soft_net: final test_loss: %.4f, final test_accuracy: %.2f"
+        % (test_loss, test_accuracy * 100)
+    )
     hard_weights = harden.hard_weights(trained_state.params)
     hard_trained_state = train_state.TrainState.create(
-        apply_fn=hard.apply, params=hard_weights, tx=optax.sgd(1.0, 1.0))
-    hard_input = harden.harden(test_ds['image'])
+        apply_fn=hard.apply, params=hard_weights, tx=optax.sgd(1.0, 1.0)
+    )
+    hard_input = harden.harden(test_ds["image"])
     hard_test_accuracy = apply_hard_model_to_images(
-        hard_trained_state, hard_input, test_ds['label'])
-    print('hard_net: final test_accuracy: %.2f' % (hard_test_accuracy * 100))
+        hard_trained_state, hard_input, test_ds["label"]
+    )
+    print("hard_net: final test_accuracy: %.2f" % (hard_test_accuracy * 100))
     assert np.isclose(test_accuracy, hard_test_accuracy, atol=0.0001)
+    # TODO: activate these checks
     if False:
         # It takes too long to compute this
         symbolic_weights = harden.symbolic_weights(trained_state.params)
         symbolic_trained_state = train_state.TrainState.create(
-            apply_fn=symbolic.apply, params=symbolic_weights, tx=optax.sgd(1.0, 1.0))
+            apply_fn=symbolic.apply, params=symbolic_weights, tx=optax.sgd(1.0, 1.0)
+        )
         symbolic_input = hard_input.tolist()
         symbolic_test_accuracy = apply_hard_model_to_images(
-            symbolic_trained_state, symbolic_input, test_ds['label'])
-        print('symbolic_net: final test_accuracy: %.2f' %
-              (symbolic_test_accuracy * 100))
-        assert (np.isclose(test_accuracy, symbolic_test_accuracy, atol=0.0001))
+            symbolic_trained_state, symbolic_input, test_ds["label"]
+        )
+        print(
+            "symbolic_net: final test_accuracy: %.2f" % (symbolic_test_accuracy * 100)
+        )
+        assert np.isclose(test_accuracy, symbolic_test_accuracy, atol=0.0001)
     if False:
         # CPU and GPU give different results, so we can't easily regress on a static symbolic expression
         symbolic_input = [f"x{i}" for i in range(len(hard_input[0].tolist()))]
-        symbolic_output = symbolic.apply(
-            {'params': symbolic_weights}, symbolic_input)
+        symbolic_output = symbolic.apply({"params": symbolic_weights}, symbolic_input)
         print("symbolic_output", symbolic_output[0][:10000])
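check_symbolic verifies that discretizing the trained soft net (harden.hard_weights on the parameters, harden.harden on the inputs) preserves test accuracy to within atol=0.0001. A sketch of the core idea, under the assumption that hardening is thresholding at 0.5, the convention hard_real_encoder above makes explicit; the real harden.harden may differ in detail:

    import jax.numpy as jnp

    def harden(x):
        # assumed convention: a soft value becomes True iff it exceeds 0.5
        return x > 0.5

    soft_pixels = jnp.array([0.1, 0.7, 0.5])
    print(harden(soft_pixels))  # [False  True False]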

@@ -263,23 +269,20 @@ def test_mnist():
     # Define the model.
     # soft = CNN()
-    width = 100
-    soft, _, _ = neural_logic_net.net(
-        lambda type, x: batch_nln(type, x, width))
+    width = 1000
+    soft, _, _ = neural_logic_net.net(lambda type, x: batch_nln(type, x, width))

     # Get the MNIST dataset.
     train_ds, test_ds = get_datasets()
     # If we're using a NLN then flatten the images
-    train_ds["image"] = jnp.reshape(
-        train_ds["image"], (train_ds["image"].shape[0], -1))
-    test_ds["image"] = jnp.reshape(
-        test_ds["image"], (test_ds["image"].shape[0], -1))
+    train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1))
+    test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1))

     # Train and evaluate the model.
     trained_state = train_and_evaluate(
-        soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics")
+        soft, (train_ds, test_ds), config=config, workdir="./mnist_metrics"
+    )

     # Check symbolic net
-    _, hard, symbolic = neural_logic_net.net(
-        lambda type, x: nln(type, x, width))
+    _, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x, width))
     check_symbolic((soft, hard, symbolic), (train_ds, test_ds), trained_state)
