Commit edaac8f

get more of test_hard_not working
1 parent 4c4fbf2 commit edaac8f

File tree

4 files changed: +175 -131 lines changed

neurallogic/hard_not.py

Lines changed: 8 additions & 56 deletions
@@ -3,7 +3,7 @@
 import jax
 from flax import linen as nn
 
-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation
 
 
 def soft_not(w: float, x: float) -> float:
@@ -18,57 +18,24 @@ def soft_not(w: float, x: float) -> float:
     return 1.0 - w + x * (2.0 * w - 1.0)
 
 
-@jax.jit
 def hard_not(w: bool, x: bool) -> bool:
-    return ~(x ^ w)
-
-
-def symbolic_not(w, x):
-    expression = f"(not({x} ^ {w}))"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+    return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))
 
 
 soft_not_neuron = jax.vmap(soft_not, 0, 0)
 
 hard_not_neuron = jax.vmap(hard_not, 0, 0)
 
 
-def symbolic_not_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not(wi, xi) for wi, xi in zip(w, x)]
-
 
 soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)
 
 hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)
 
 
-def symbolic_not_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not_neuron(wi, x) for wi in w]
 
 
 class SoftNotLayer(nn.Module):
-    """
-    A soft-bit NOT layer than transforms its inputs along the last dimension.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-        weights_init: The initializer function for the weight matrix.
-    """
     layer_size: int
     weights_init: Callable = nn.initializers.uniform(1.0)
     dtype: jax.numpy.dtype = jax.numpy.float32
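Note (not part of the commit): the rewritten hard_not computes the boolean XNOR of weight and input, the hard counterpart of soft_not's 1.0 - w + x * (2.0 * w - 1.0): when the weight is True the input passes through, when it is False the input is inverted. A minimal sketch of the elementwise behaviour on JAX boolean arrays:

# Minimal sketch (not from the repo): elementwise behaviour of the new hard_not.
import jax.numpy as jnp

def hard_not(w, x):
    # NOT(x XOR w): returns x where w is True, NOT x where w is False
    return jnp.logical_not(jnp.logical_xor(x, w))

w = jnp.array([True, True, False, False])
x = jnp.array([True, False, True, False])
print(hard_not(w, x))  # [ True False False  True]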
@@ -83,13 +50,6 @@ def __call__(self, x):
 
 
 class HardNotLayer(nn.Module):
-    """
-    A hard-bit NOT layer that shadows the SoftNotLayer.
-    This is a convenience class to make it easier to switch between soft and hard logic.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
     layer_size: int
 
     @nn.compact
@@ -100,22 +60,14 @@ def __call__(self, x):
         return hard_not_layer(weights, x)
 
 
-class SymbolicNotLayer(nn.Module):
-    """A symbolic NOT layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
+class SymbolicNotLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_not_layer = HardNotLayer(self.layer_size)
 
-    @nn.compact
     def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_not_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_not_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
 
 
 not_layer = neural_logic_net.select(
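Note (not part of the commit): the new SymbolicNotLayer no longer builds expression strings by hand; it delegates to symbolic_generation.make_symbolic_flax_jaxpr and symbolic_expression, which are not shown in this diff. The sketch below only illustrates the standard JAX mechanism such helpers presumably build on, namely tracing the hard computation into a jaxpr:

# Illustrative only: tracing a hard NOT computation into a jaxpr with jax.make_jaxpr.
# The repo's make_symbolic_flax_jaxpr/symbolic_expression helpers are assumed to walk
# a jaxpr like this one and emit a symbolic expression for each primitive.
import jax
import jax.numpy as jnp

def hard_not_neuron(w, x):
    return jnp.logical_not(jnp.logical_xor(x, w))

w = jnp.array([True, False])
x = jnp.array([False, False])
print(jax.make_jaxpr(hard_not_neuron)(w, x))
# Prints a jaxpr containing xor and not primitives applied to bool[2] inputs.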

neurallogic/symbolic_primitives.py

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ def binary_infix_operator(operator: str, a: numpy.ndarray, b: float, bracket: bool = False):
 def binary_infix_operator(operator: str, a: str, b: float, bracket: bool = False):
     return binary_infix_operator(operator, a, str(b), bracket)
 
+@dispatch
+def binary_infix_operator(operator: str, a: numpy.ndarray, b: jax.numpy.ndarray, bracket: bool = False):
+    return binary_infix_operator(operator, a, numpy.array(b), bracket)
 
 
 def all_concrete_values(data):
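Note (not part of the commit): the new overload handles the case where the right-hand operand arrives as a JAX array by converting it to NumPy and re-dispatching to the existing numpy.ndarray overloads. A small sketch of the conversion it relies on:

# numpy.array() materialises a JAX array as a host NumPy array, so the existing
# NumPy-based binary_infix_operator overloads can process it unchanged.
import jax.numpy as jnp
import numpy

b = jnp.array([True, False, True])
b_np = numpy.array(b)
print(type(b_np).__name__, b_np.dtype)  # ndarray bool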

tests/test_hard_and.py

Lines changed: 7 additions & 7 deletions
@@ -121,7 +121,7 @@ def test_net(type, x):
         return hard_and.and_layer(type)(4, nn.initializers.uniform(1.0))(x)
 
     soft, hard, symbolic = neural_logic_net.net(test_net)
-    soft_weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
+    weights = soft.init(random.PRNGKey(0), [0.0, 0.0])
 
     x = [
         [1.0, 1.0],
@@ -141,16 +141,16 @@ def test_net(type, x):
     # Train the and layer
    tx = optax.sgd(0.1)
     state = train_state.TrainState.create(apply_fn=jax.vmap(
-        soft.apply, in_axes=(None, 0)), params=soft_weights, tx=tx)
+        soft.apply, in_axes=(None, 0)), params=weights, tx=tx)
     grad_fn = jax.jit(jax.value_and_grad(lambda params, x,
                       y: jax.numpy.mean((state.apply_fn(params, x) - y) ** 2)))
     for epoch in range(1, 100):
         loss, grads = grad_fn(state.params, input, output)
         state = state.apply_gradients(grads=grads)
 
     # Test that the and layer (both soft and hard variants) correctly predicts y
-    soft_weights = state.params
-    hard_weights = harden.hard_weights(soft_weights)
+    weights = state.params
+    hard_weights = harden.hard_weights(weights)
 
     for input, expected in zip(x, y):
         hard_input = harden.harden(jax.numpy.array(input))
@@ -171,11 +171,11 @@ def test_net(type, x):
 
     # Compute soft result
     soft_input = jax.numpy.array([0.6, 0.45])
-    soft_weights = soft.init(random.PRNGKey(0), soft_input)
-    soft_result = soft.apply(soft_weights, numpy.array(soft_input))
+    weights = soft.init(random.PRNGKey(0), soft_input)
+    soft_result = soft.apply(weights, numpy.array(soft_input))
 
     # Compute hard result
-    hard_weights = harden.hard_weights(soft_weights)
+    hard_weights = harden.hard_weights(weights)
     hard_input = harden.harden(soft_input)
     hard_result = hard.apply(hard_weights, numpy.array(hard_input))
     # Check that the hard result is the same as the soft result
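Note (not part of the commit): the test changes above are a straight rename of soft_weights to weights; the training pattern itself is unchanged. For readers unfamiliar with it, a self-contained sketch of that pattern with a generic Flax module and toy AND data (names hypothetical, not from the repo):

# Sketch of the test's training pattern: vmap the module's apply over the batch,
# wrap it in a TrainState, and take SGD steps on a mean-squared-error loss.
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

class TinyNet(nn.Module):  # hypothetical stand-in for the soft logic net under test
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x).squeeze(-1)

net = TinyNet()
weights = net.init(jax.random.PRNGKey(0), jnp.zeros(2))
state = train_state.TrainState.create(
    apply_fn=jax.vmap(net.apply, in_axes=(None, 0)), params=weights, tx=optax.sgd(0.1))

xs = jnp.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
ys = jnp.array([1.0, 0.0, 0.0, 0.0])

grad_fn = jax.jit(jax.value_and_grad(
    lambda params, x, y: jnp.mean((state.apply_fn(params, x) - y) ** 2)))
for epoch in range(1, 100):
    loss, grads = grad_fn(state.params, xs, ys)
    state = state.apply_gradients(grads=grads)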
