Skip to content

Commit bd9f84e

Browse files
authored
Merge pull request #52 from github/jaxpr-1
symbolic interpretation of jaxpr -- part 1
2 parents ff3c39a + 8efc75e commit bd9f84e

15 files changed

+1191
-387
lines changed

.devcontainer/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ numpy
22
pandas
33
pytest
44
jupyter
5+
plum-dispatch
56

67
tensorflow
78
tensorflow_datasets

neurallogic/hard_not.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,18 @@ def soft_not(w: float, x: float) -> float:
1111
w > 0.5 implies the not operation is inactive, else active
1212
1313
Assumes x is in [0, 1]
14-
14+
1515
Corresponding hard logic: ! (x XOR w)
1616
"""
1717
w = jax.numpy.clip(w, 0.0, 1.0)
1818
return 1.0 - w + x * (2.0 * w - 1.0)
1919

20+
2021
@jax.jit
2122
def hard_not(w: bool, x: bool) -> bool:
2223
return ~(x ^ w)
2324

25+
2426
def symbolic_not(w, x):
2527
expression = f"(not({x} ^ {w}))"
2628
# Check if w is of type bool
@@ -30,10 +32,12 @@ def symbolic_not(w, x):
3032
# We don't know the value of w or x, so we return the expression
3133
return expression
3234

35+
3336
soft_not_neuron = jax.vmap(soft_not, 0, 0)
3437

3538
hard_not_neuron = jax.vmap(hard_not, 0, 0)
3639

40+
3741
def symbolic_not_neuron(w, x):
3842
# TODO: ensure that this implementation has the same generality over tensors as vmap
3943
if not isinstance(w, list):
@@ -42,10 +46,12 @@ def symbolic_not_neuron(w, x):
4246
raise TypeError(f"Input {x} should be a list")
4347
return [symbolic_not(wi, xi) for wi, xi in zip(w, x)]
4448

49+
4550
soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)
4651

4752
hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)
4853

54+
4955
def symbolic_not_layer(w, x):
5056
# TODO: ensure that this implementation has the same generality over tensors as vmap
5157
if not isinstance(w, list):
@@ -54,6 +60,7 @@ def symbolic_not_layer(w, x):
5460
raise TypeError(f"Input {x} should be a list")
5561
return [symbolic_not_neuron(wi, x) for wi in w]
5662

63+
5764
class SoftNotLayer(nn.Module):
5865
"""
5966
A soft-bit NOT layer than transforms its inputs along the last dimension.
@@ -69,10 +76,12 @@ class SoftNotLayer(nn.Module):
6976
@nn.compact
7077
def __call__(self, x):
7178
weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
72-
weights = self.param('weights', self.weights_init, weights_shape, self.dtype)
79+
weights = self.param('weights', self.weights_init,
80+
weights_shape, self.dtype)
7381
x = jax.numpy.asarray(x, self.dtype)
7482
return soft_not_layer(weights, x)
7583

84+
7685
class HardNotLayer(nn.Module):
7786
"""
7887
A hard-bit NOT layer that shadows the SoftNotLayer.
@@ -86,9 +95,11 @@ class HardNotLayer(nn.Module):
8695
@nn.compact
8796
def __call__(self, x):
8897
weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
89-
weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
98+
weights = self.param(
99+
'weights', nn.initializers.constant(0.0), weights_shape)
90100
return hard_not_layer(weights, x)
91101

102+
92103
class SymbolicNotLayer(nn.Module):
93104
"""A symbolic NOT layer than transforms its inputs along the last dimension.
94105
Attributes:
@@ -99,13 +110,17 @@ class SymbolicNotLayer(nn.Module):
99110
@nn.compact
100111
def __call__(self, x):
101112
weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
102-
weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
113+
weights = self.param(
114+
'weights', nn.initializers.constant(0.0), weights_shape)
103115
weights = weights.tolist()
104116
if not isinstance(x, list):
105117
raise TypeError(f"Input {x} should be a list")
106118
return symbolic_not_layer(weights, x)
107119

120+
108121
not_layer = neural_logic_net.select(
109-
lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
110-
lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
122+
lambda layer_size, weights_init=nn.initializers.uniform(
123+
1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
124+
lambda layer_size, weights_init=nn.initializers.uniform(
125+
1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
111126
lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size))

neurallogic/harden.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ def hard_weights(weights):
3131

3232
def symbolic_weights(weights):
3333
return flax.core.FrozenDict(map_keys_nested(lambda str: str.replace("Soft", "Symbolic"), harden(weights.unfreeze())))
34+

neurallogic/sym_gen.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import jax
2+
import jax._src.lax_reference as lax_reference
3+
from jax import core
4+
from jax._src.util import safe_map
5+
import numpy
6+
from neurallogic import symbolic_primitives
7+
8+
9+
def symbolic_bind(prim, *args, **params):
10+
#print("\n---symbolic_bind:")
11+
#print("primitive: ", prim.name)
12+
#print("args: ", args)
13+
#print("params: ", params)
14+
symbolic_outvals = {
15+
'and': symbolic_primitives.symbolic_and,
16+
'broadcast_in_dim': symbolic_primitives.symbolic_broadcast_in_dim,
17+
'xor': symbolic_primitives.symbolic_xor,
18+
'not': symbolic_primitives.symbolic_not,
19+
'reshape': lax_reference.reshape,
20+
'reduce_or': symbolic_primitives.symbolic_reduce_or,
21+
'reduce_sum': symbolic_primitives.symbolic_reduce_sum,
22+
'convert_element_type': symbolic_primitives.symbolic_convert_element_type
23+
}[prim.name](*args, **params)
24+
return symbolic_outvals
25+
26+
27+
28+
def eval_jaxpr(symbolic, jaxpr, consts, *args):
29+
"""Evaluates a jaxpr by interpreting it as Python code.
30+
31+
Parameters
32+
----------
33+
symbolic : bool
34+
Whether to return symbolic values or concrete values. If symbolic is
35+
True, returns symbolic values, and if symbolic is False, returns
36+
concrete values.
37+
jaxpr : Jaxpr
38+
The jaxpr to interpret.
39+
consts : tuple
40+
Constant values for the jaxpr.
41+
args : tuple
42+
Arguments for the jaxpr.
43+
44+
Returns
45+
-------
46+
out : tuple
47+
The result of evaluating the jaxpr.
48+
"""
49+
50+
# Mapping from variable -> value
51+
env = {}
52+
symbolic_env = {}
53+
54+
def read(var):
55+
# Literals are values baked into the Jaxpr
56+
if type(var) is core.Literal:
57+
return var.val
58+
return env[var]
59+
60+
def symbolic_read(var):
61+
return symbolic_env[var]
62+
63+
def write(var, val):
64+
env[var] = val
65+
66+
def symbolic_write(var, val):
67+
symbolic_env[var] = val
68+
69+
# Bind args and consts to environment
70+
if not symbolic:
71+
safe_map(write, jaxpr.invars, args)
72+
safe_map(write, jaxpr.constvars, consts)
73+
safe_map(symbolic_write, jaxpr.invars, args)
74+
safe_map(symbolic_write, jaxpr.constvars, consts)
75+
76+
def eval_jaxpr_impl(jaxpr):
77+
# Loop through equations and evaluate primitives using `bind`
78+
for eqn in jaxpr.eqns:
79+
# Read inputs to equation from environment
80+
if not symbolic:
81+
invals = safe_map(read, eqn.invars)
82+
symbolic_invals = safe_map(symbolic_read, eqn.invars)
83+
prim = eqn.primitive
84+
if type(prim) is jax.core.CallPrimitive:
85+
# print(f"call primitive: {prim.name}")
86+
call_jaxpr = eqn.params['call_jaxpr']
87+
if not symbolic:
88+
safe_map(write, call_jaxpr.invars, map(read, eqn.invars))
89+
safe_map(symbolic_write, call_jaxpr.invars,
90+
map(symbolic_read, eqn.invars))
91+
eval_jaxpr_impl(call_jaxpr)
92+
if not symbolic:
93+
safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars))
94+
safe_map(symbolic_write, eqn.outvars, map(
95+
symbolic_read, call_jaxpr.outvars))
96+
else:
97+
# print(f"primitive: {prim.name}")
98+
if not symbolic:
99+
outvals = prim.bind(*invals, **eqn.params)
100+
symbolic_outvals = symbolic_bind(
101+
prim, *symbolic_invals, **eqn.params)
102+
# Primitives may return multiple outputs or not
103+
if not prim.multiple_results:
104+
if not symbolic:
105+
outvals = [outvals]
106+
symbolic_outvals = [symbolic_outvals]
107+
if not symbolic:
108+
#print(f"outvals: {type(outvals)}: {outvals}")
109+
#print(
110+
# f"symbolic_outvals: {type(symbolic_outvals)}: {symbolic_outvals}")
111+
# Check that the concrete and symbolic values are equal
112+
assert numpy.array_equal(
113+
numpy.array(outvals), symbolic_outvals)
114+
# Write the results of the primitive into the environment
115+
if not symbolic:
116+
safe_map(write, eqn.outvars, outvals)
117+
safe_map(symbolic_write, eqn.outvars, symbolic_outvals)
118+
119+
# Read the final result of the Jaxpr from the environment
120+
eval_jaxpr_impl(jaxpr)
121+
if not symbolic:
122+
return safe_map(read, jaxpr.outvars)[0]
123+
else:
124+
return safe_map(symbolic_read, jaxpr.outvars)[0]
125+
126+
127+
def eval_jaxpr_concrete(jaxpr, *args):
128+
return eval_jaxpr(False, jaxpr.jaxpr, jaxpr.literals, *args)
129+
130+
131+
def eval_jaxpr_symbolic(jaxpr, *args):
132+
symbolic_jaxpr_literals = safe_map(lambda x: numpy.array(x, dtype=object), jaxpr.literals)
133+
symbolic_jaxpr_literals = symbolic_primitives.to_boolean_symbolic_values(symbolic_jaxpr_literals)
134+
return eval_jaxpr(True, jaxpr.jaxpr, symbolic_jaxpr_literals, *args)
135+

0 commit comments

Comments
 (0)