Commit c1774da

Merge pull request #53 from github/remove-old-symbolic-eval
New symbolic evaluator
2 parents: bd9f84e + 967d083

21 files changed: +1007 −1075 lines

neurallogic/hard_and.py

Lines changed: 30 additions & 53 deletions
@@ -1,71 +1,49 @@
-from functools import reduce
-from typing import Callable
+from typing import Any
 
+import numpy
 import jax
 from flax import linen as nn
+from typing import Callable
+
 
-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation
 
 
 def soft_and_include(w: float, x: float) -> float:
     """
     w > 0.5 implies the and operation is active, else inactive
 
     Assumes x is in [0, 1]
-
+
     Corresponding hard logic: x OR ! w
     """
     w = jax.numpy.clip(w, 0.0, 1.0)
     return jax.numpy.maximum(x, 1.0 - w)
 
-@jax.jit
-def hard_and_include(w: bool, x: bool) -> bool:
-    return x | ~w
 
-def symbolic_and_include(w, x):
-    expression = f"({x} or not({w}))"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+
+def hard_and_include(w, x):
+    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
+
+
 
 def soft_and_neuron(w, x):
     x = jax.vmap(soft_and_include, 0, 0)(w, x)
     return jax.numpy.min(x)
 
+
 def hard_and_neuron(w, x):
     x = jax.vmap(hard_and_include, 0, 0)(w, x)
     return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])
 
-def symbolic_and_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    y = [symbolic_and_include(wi, xi) for wi, xi in zip(w, x)]
-    expression = "(" + str(reduce(lambda a, b: f"{a} and {b}", y)) + ")"
-    if all(isinstance(yi, bool) for yi in y):
-        # We know the value of all yis, so we can evaluate the expression
-        return eval(expression)
-    return expression
 
 soft_and_layer = jax.vmap(soft_and_neuron, (0, None), 0)
 
 hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)
 
-def symbolic_and_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_and_neuron(wi, x) for wi in w]
 
-# TODO: investigate better initialization
 def initialize_near_to_zero():
+    # TODO: investigate better initialization
     def init(key, shape, dtype):
         dtype = jax.dtypes.canonicalize_dtype(dtype)
         # Sample from standard normal distribution (zero mean, unit variance)
@@ -76,6 +54,7 @@ def init(key, shape, dtype):
         return x
     return init
 
+
 class SoftAndLayer(nn.Module):
     """
     A soft-bit AND layer than transforms its inputs along the last dimension.
@@ -91,10 +70,12 @@ class SoftAndLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init, weights_shape, self.dtype)
+        weights = self.param('weights', self.weights_init,
+                             weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_and_layer(weights, x)
 
+
 class HardAndLayer(nn.Module):
     """
     A hard-bit And layer that shadows the SoftAndLayer.
@@ -108,26 +89,22 @@ class HardAndLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param(
+            'weights', nn.initializers.constant(0.0), weights_shape)
         return hard_and_layer(weights, x)
 
-class SymbolicAndLayer(nn.Module):
-    """A symbolic And layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
 
-    @nn.compact
+class SymbolicAndLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_and_layer = HardAndLayer(self.layer_size)
+
     def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_and_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_and_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
 
 and_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
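
Aside (not part of the commit): the new hard AND path replaces bitwise operators on Python bools with jax.numpy logical ops. A minimal sketch of how those primitives behave on concrete boolean arrays, with the weights and inputs invented purely for illustration:

import jax

# Copied from the diff above: x OR NOT w, so an input only constrains the AND
# when its weight marks it as included (w == True).
def hard_and_include(w, x):
    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))

# Copied from the diff above: AND-reduce the per-input results.
def hard_and_neuron(w, x):
    x = jax.vmap(hard_and_include, 0, 0)(w, x)
    return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])

w = jax.numpy.array([True, True, False])   # include the first two inputs only
x = jax.numpy.array([True, False, True])
print(hard_and_neuron(w, x))               # False: an included input is False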

neurallogic/hard_not.py

Lines changed: 9 additions & 56 deletions
@@ -3,7 +3,7 @@
 import jax
 from flax import linen as nn
 
-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation
 
 
 def soft_not(w: float, x: float) -> float:
@@ -18,57 +18,25 @@ def soft_not(w: float, x: float) -> float:
     return 1.0 - w + x * (2.0 * w - 1.0)
 
 
-@jax.jit
 def hard_not(w: bool, x: bool) -> bool:
-    return ~(x ^ w)
-
-
-def symbolic_not(w, x):
-    expression = f"(not({x} ^ {w}))"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+    # ~(x ^ w)
+    return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))
 
 
 soft_not_neuron = jax.vmap(soft_not, 0, 0)
 
 hard_not_neuron = jax.vmap(hard_not, 0, 0)
 
 
-def symbolic_not_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not(wi, xi) for wi, xi in zip(w, x)]
-
 
 soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)
 
 hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)
 
 
-def symbolic_not_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_not_neuron(wi, x) for wi in w]
 
 
 class SoftNotLayer(nn.Module):
-    """
-    A soft-bit NOT layer than transforms its inputs along the last dimension.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-        weights_init: The initializer function for the weight matrix.
-    """
     layer_size: int
     weights_init: Callable = nn.initializers.uniform(1.0)
     dtype: jax.numpy.dtype = jax.numpy.float32
@@ -83,13 +51,6 @@ def __call__(self, x):
 
 
 class HardNotLayer(nn.Module):
-    """
-    A hard-bit NOT layer that shadows the SoftNotLayer.
-    This is a convenience class to make it easier to switch between soft and hard logic.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
     layer_size: int
 
     @nn.compact
@@ -100,22 +61,14 @@ def __call__(self, x):
         return hard_not_layer(weights, x)
 
 
-class SymbolicNotLayer(nn.Module):
-    """A symbolic NOT layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
+class SymbolicNotLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_not_layer = HardNotLayer(self.layer_size)
 
-    @nn.compact
     def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_not_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_not_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
 
 
 not_layer = neural_logic_net.select(
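
Aside (not part of the commit): the new jax.numpy-based hard_not above is an XNOR, so a True weight passes the input through and a False weight negates it. A minimal sketch on invented inputs:

import jax

# Copied from the diff above: ~(x ^ w)
def hard_not(w, x):
    return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))

w = jax.numpy.array([True, False])
x = jax.numpy.array([True, True])
print(hard_not(w, x))  # [ True False]: first input passed through, second negated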

neurallogic/hard_or.py

Lines changed: 10 additions & 59 deletions
@@ -4,7 +4,7 @@
 import jax
 from flax import linen as nn
 
-from neurallogic import neural_logic_net
+from neurallogic import neural_logic_net, symbolic_generation
 
 
 def soft_or_include(w: float, x: float) -> float:
@@ -18,18 +18,10 @@ def soft_or_include(w: float, x: float) -> float:
     w = jax.numpy.clip(w, 0.0, 1.0)
     return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
 
-@jax.jit
-def hard_or_include(w: bool, x: bool) -> bool:
-    return x & w
 
-def symbolic_or_include(w, x):
-    expression = f"({x} and {w})"
-    # Check if w is of type bool
-    if isinstance(w, bool) and isinstance(x, bool):
-        # We know the value of w and x, so we can evaluate the expression
-        return eval(expression)
-    # We don't know the value of w or x, so we return the expression
-    return expression
+def hard_or_include(w, x):
+    return jax.numpy.logical_and(x, w)
+
 
 def soft_or_neuron(w, x):
     x = jax.vmap(soft_or_include, 0, 0)(w, x)
@@ -39,31 +31,11 @@ def hard_or_neuron(w, x):
     x = jax.vmap(hard_or_include, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])
 
-def symbolic_or_neuron(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    y = [symbolic_or_include(wi, xi) for wi, xi in zip(w, x)]
-    expression = "(" + str(reduce(lambda a, b: f"{a} or {b}", y)) + ")"
-    if all(isinstance(yi, bool) for yi in y):
-        # We know the value of all yis, so we can evaluate the expression
-        return eval(expression)
-    return expression
 
 soft_or_layer = jax.vmap(soft_or_neuron, (0, None), 0)
 
 hard_or_layer = jax.vmap(hard_or_neuron, (0, None), 0)
 
-def symbolic_or_layer(w, x):
-    # TODO: ensure that this implementation has the same generality over tensors as vmap
-    if not isinstance(w, list):
-        raise TypeError(f"Input {x} should be a list")
-    if not isinstance(x, list):
-        raise TypeError(f"Input {x} should be a list")
-    return [symbolic_or_neuron(wi, x) for wi in w]
-
 # TODO: investigate better initialization
 def initialize_near_to_one():
     def init(key, shape, dtype):
@@ -77,13 +49,6 @@ def init(key, shape, dtype):
     return init
 
 class SoftOrLayer(nn.Module):
-    """
-    A soft-bit Or layer than transforms its inputs along the last dimension.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-        weights_init: The initializer function for the weight matrix.
-    """
     layer_size: int
     weights_init: Callable = initialize_near_to_one()
     dtype: jax.numpy.dtype = jax.numpy.float32
@@ -96,13 +61,6 @@ def __call__(self, x):
         return soft_or_layer(weights, x)
 
 class HardOrLayer(nn.Module):
-    """
-    A hard-bit Or layer that shadows the SoftAndLayer.
-    This is a convenience class to make it easier to switch between soft and hard logic.
-
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
     layer_size: int
 
     @nn.compact
@@ -111,21 +69,14 @@ def __call__(self, x):
         weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
         return hard_or_layer(weights, x)
 
-class SymbolicOrLayer(nn.Module):
-    """A symbolic Or layer than transforms its inputs along the last dimension.
-    Attributes:
-        layer_size: The number of neurons in the layer.
-    """
-    layer_size: int
+class SymbolicOrLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_or_layer = HardOrLayer(self.layer_size)
 
-    @nn.compact
     def __call__(self, x):
-        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
-        weights = weights.tolist()
-        if not isinstance(x, list):
-            raise TypeError(f"Input {x} should be a list")
-        return symbolic_or_layer(weights, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
 
 or_layer = neural_logic_net.select(
     lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
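
Aside (not part of the commit): the hard OR path mirrors the AND case, with x AND w as the include step and an OR-reduction over the neuron. A minimal sketch on invented inputs:

import jax

# Copied from the diff above: an input contributes only when its weight includes it.
def hard_or_include(w, x):
    return jax.numpy.logical_and(x, w)

# Copied from the diff above: OR-reduce the included inputs.
def hard_or_neuron(w, x):
    x = jax.vmap(hard_or_include, 0, 0)(w, x)
    return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])

w = jax.numpy.array([False, True, True])   # exclude the first input
x = jax.numpy.array([True, False, True])
print(hard_or_neuron(w, x))                # True: an included input is True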
