Skip to content

Commit 16d23db

Browse files
committed
get test_network working
1 parent dc48979 commit 16d23db

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

neurallogic/symbolic_primitives.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ def binary_infix_operator(operator: str, a: str, b: float):
157157
def binary_infix_operator(operator: str, a: numpy.ndarray, b: jax.numpy.ndarray):
158158
return binary_infix_operator(operator, a, numpy.array(b))
159159

160+
@dispatch
161+
def binary_infix_operator(operator: str, a: bool, b: str):
162+
return binary_infix_operator(operator, str(a), b)
163+
164+
160165

161166
def all_concrete_values(data):
162167
if isinstance(data, str):

tests/test_mnist.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@ def nln(type, x, width):
2626
x = hard_or.or_layer(type)(width, nn.initializers.uniform(
2727
1.0), dtype=jnp.float32)(x) # >=1700 need for >98% accuracy
2828
x = hard_not.not_layer(type)(10, dtype=jnp.float32)(x)
29-
x = primitives.nl_ravel(type)(x) # flatten the outputs of the not layer
29+
x = x.ravel() # flatten the outputs of the not layer
3030
# harden the outputs of the not layer
3131
x = harden_layer.harden_layer(type)(x)
32-
x = primitives.nl_reshape(type)((10, width))(
33-
x) # reshape to 10 ports, 100 bits each
34-
x = primitives.nl_sum(type)(-1)(x) # sum the 100 bits in each port
32+
x = x.reshape((10, width)) # reshape to 10 ports, 100 bits each
33+
x = x.sum(-1) # sum the 100 bits in each port
3534
return x
3635

3736

tests/test_network.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from neurallogic import (hard_and, hard_not, hard_or, harden, harden_layer,
2-
neural_logic_net, primitives)
2+
neural_logic_net, symbolic_generation)
33
from jax import random
44
from flax.training import train_state
55
from flax import linen as nn
66
import optax
77
import jax.numpy as jnp
88
import jax
99
from jax.config import config
10+
import numpy
11+
1012
config.update("jax_enable_x64", True)
1113

1214

@@ -17,7 +19,7 @@ def test_net(type, x):
1719
x = hard_and.and_layer(type)(
1820
4, nn.initializers.uniform(1.0), jnp.float64)(x)
1921
x = hard_not.not_layer(type)(1, dtype=jnp.float64)(x)
20-
x = primitives.nl_ravel(type)(x)
22+
x = x.ravel()
2123
x = harden_layer.harden_layer(type)(x)
2224
return x
2325

@@ -55,20 +57,27 @@ def test_net(type, x):
5557

5658
# Test that the and layer (both soft and hard variants) correctly predicts y
5759
for input, expected in zip(x, y):
60+
5861
input = jnp.array(input)
5962
expected = jnp.array(expected)
6063
soft_result = soft.apply(best_weights, input)
6164
assert jnp.allclose(soft_result, expected)
65+
6266
hard_input = harden.harden(input)
6367
hard_expected = harden.harden(expected)
6468
hard_weights = harden.hard_weights(best_weights)
6569
hard_result = hard.apply(hard_weights, hard_input)
6670
assert jnp.array_equal(hard_result, hard_expected)
67-
symbolic_weights = harden.symbolic_weights(best_weights)
68-
symbolic_result = symbolic.apply(symbolic_weights, hard_input.tolist())
71+
72+
symbolic_weights = symbolic_generation.make_symbolic(hard_weights)
73+
symbolic_result = symbolic.apply(symbolic_weights, hard_input)
74+
symbolic_result = symbolic_generation.eval_symbolic_expression(symbolic_result)
6975
assert jnp.array_equal(symbolic_result, hard_expected)
76+
7077
symbolic_input = ['x1', 'x2']
71-
symbolic_expected = ['(not(((((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and False) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and False)) or not(False)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True))) ^ False))', '(not(((((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and False)) or not(False)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False))) ^ True))',
72-
'(not(((((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and False) or (x2 and False)) or not(True)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True))) ^ False))', '(not(((((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and False)) or not(False)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and False)) or not(True)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and True) or (x2 and True)) or not(False)) and (((x1 and False) or (x2 and False)) or not(True)) and (((x1 and False) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True)) and (((x1 and True) or (x2 and True)) or not(True))) ^ True))']
78+
symbolic_expected = ['not((True and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) != 0) ^ (False != 0))',
79+
'not((True and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) != 0) ^ (True != 0))',
80+
'not((True and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) != 0) ^ (False != 0))',
81+
'not((True and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((False != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (False != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (False != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) and ((((False or (x1 != 0) and (True != 0)) or (x2 != 0) and (True != 0)) != 0) or not((True != 0))) != 0) ^ (True != 0))']
7382
symbolic_result = symbolic.apply(symbolic_weights, symbolic_input)
74-
assert symbolic_result == symbolic_expected
83+
assert numpy.array_equal(symbolic_result, symbolic_expected)

0 commit comments

Comments
 (0)