Skip to content

Commit ce5195d

Browse files
committed
wip: adding a hard_bit layer
1 parent c1774da commit ce5195d

File tree

8 files changed

+552
-186
lines changed

8 files changed

+552
-186
lines changed

neurallogic/hard_bit.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import jax
2+
3+
4+
def soft_bit(t: float, x: float) -> float:
    """Piecewise-linear soft threshold of x at threshold t.

    Maps x in [0, 1] to [0, 1]: values below t are scaled onto [0, 0.5),
    values above t onto (0.5, 1], and x == t maps exactly to 0.5.
    """
    # Threshold is forced into the unit interval before use.
    t = jax.numpy.clip(t, 0.0, 1.0)
    # Linear ramp below the threshold: [0, t) -> [0, 0.5).
    lower_ramp = (1.0 / (2.0 * t)) * x
    # Linear ramp above the threshold: (t, 1] -> (0.5, 1].
    upper_ramp = (1.0 / (2.0 * (1.0 - t))) * (x + 1.0 - 2.0 * t)
    ramp = jax.numpy.where(x < t, lower_ramp, upper_ramp)
    # Exact equality is pinned to 0.5 so the midpoint is well-defined.
    return jax.numpy.where(x == t, 0.5, ramp)
16+
17+
18+
def hard_bit(t: float, x: float) -> bool:
    """Hard-threshold x at t: True iff soft_bit(t, x) exceeds 0.5.

    t and x must be floats; returns a boolean (array) of the same shape.
    """
    # The comparison already yields a boolean; the original
    # jax.numpy.where(cond, True, False) wrapper was redundant.
    return soft_bit(t, x) > 0.5
21+
22+
23+
# Neurons: map the scalar bit functions over a vector of thresholds,
# broadcasting a single input x across all of them.
soft_bit_neuron = jax.vmap(soft_bit, in_axes=(0, None))
hard_bit_neuron = jax.vmap(hard_bit, in_axes=(0, None))

# Layers: map each neuron over paired batches of thresholds and inputs.
soft_bit_layer = jax.vmap(soft_bit_neuron, in_axes=(0, 0), out_axes=0)
hard_bit_layer = jax.vmap(hard_bit_neuron, in_axes=(0, 0), out_axes=0)

neurallogic/hard_not.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ def hard_not(w: bool, x: bool) -> bool:
2828
hard_not_neuron = jax.vmap(hard_not, 0, 0)
2929

3030

31+
soft_not_layer = jax.vmap(soft_not_neuron, (0, 0), 0)
3132

32-
soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0)
33-
34-
hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0)
35-
36-
33+
hard_not_layer = jax.vmap(hard_not_neuron, (0, 0), 0)
3734

3835

3936
class SoftNotLayer(nn.Module):
@@ -44,8 +41,7 @@ class SoftNotLayer(nn.Module):
4441
@nn.compact
4542
def __call__(self, x):
4643
weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
47-
weights = self.param('weights', self.weights_init,
48-
weights_shape, self.dtype)
44+
weights = self.param("weights", self.weights_init, weights_shape, self.dtype)
4945
x = jax.numpy.asarray(x, self.dtype)
5046
return soft_not_layer(weights, x)
5147

@@ -56,8 +52,7 @@ class HardNotLayer(nn.Module):
5652
@nn.compact
5753
def __call__(self, x):
5854
weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
59-
weights = self.param(
60-
'weights', nn.initializers.constant(0.0), weights_shape)
55+
weights = self.param("weights", nn.initializers.constant(0.0), weights_shape)
6156
return hard_not_layer(weights, x)
6257

6358

@@ -73,7 +68,12 @@ def __call__(self, x):
7368

7469
not_layer = neural_logic_net.select(
7570
lambda layer_size, weights_init=nn.initializers.uniform(
76-
1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
71+
1.0
72+
), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
73+
lambda layer_size, weights_init=nn.initializers.uniform(
74+
1.0
75+
), dtype=jax.numpy.float32: HardNotLayer(layer_size),
7776
lambda layer_size, weights_init=nn.initializers.uniform(
78-
1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
79-
lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size))
77+
1.0
78+
), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
79+
)

neurallogic/harden.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
def harden_float(x: float) -> bool:
    """Harden a soft bit: True exactly when x is strictly greater than 0.5."""
    return x > 0.5
1010

11+
1112
harden_array = jax.vmap(harden_float, 0, 0)
1213

1314
@dispatch
def harden(x: float):
    """Harden a float to a hard bit; NaN is passed through unchanged."""
    # NaN survives hardening so that missing values stay missing.
    return x if numpy.isnan(x) else harden_float(x)
1619

1720
@dispatch

neurallogic/symbolic_generation.py

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,51 @@
11
import numpy
22
import jax
3-
import jax._src.lax_reference as lax_reference
43
from jax import core
54
from jax._src.util import safe_map
65
import flax
76
from neurallogic import symbolic_primitives
87
from plum import dispatch
98
import typing
10-
from typing import (Any, Mapping)
11-
12-
# TODO: rename this file to symbolic.py
9+
from typing import Any, Mapping
1310

1411

1512
def symbolic_bind(prim, *args, **params):
1613
# print("\nprimitive: ", prim.name)
1714
# print("args: ", args)
1815
# print("params: ", params)
1916
symbolic_outvals = {
20-
'broadcast_in_dim': symbolic_primitives.symbolic_broadcast_in_dim,
21-
'reshape': lax_reference.reshape,
22-
'convert_element_type': symbolic_primitives.symbolic_convert_element_type,
23-
'and': symbolic_primitives.symbolic_and,
24-
'or': symbolic_primitives.symbolic_or,
25-
'xor': symbolic_primitives.symbolic_xor,
26-
'not': symbolic_primitives.symbolic_not,
27-
'ne': symbolic_primitives.symbolic_ne,
28-
'gt': symbolic_primitives.symbolic_gt,
29-
'reduce_and': symbolic_primitives.symbolic_reduce_and,
30-
'reduce_or': symbolic_primitives.symbolic_reduce_or,
31-
'reduce_sum': symbolic_primitives.symbolic_reduce_sum,
17+
"broadcast_in_dim": symbolic_primitives.symbolic_broadcast_in_dim,
18+
"reshape": symbolic_primitives.symbolic_reshape,
19+
"transpose": symbolic_primitives.symbolic_transpose,
20+
"convert_element_type": symbolic_primitives.symbolic_convert_element_type,
21+
"eq": symbolic_primitives.symbolic_eq,
22+
"ne": symbolic_primitives.symbolic_ne,
23+
"le": symbolic_primitives.symbolic_le,
24+
"lt": symbolic_primitives.symbolic_lt,
25+
"gt": symbolic_primitives.symbolic_gt,
26+
"add": symbolic_primitives.symbolic_add,
27+
"sub": symbolic_primitives.symbolic_sub,
28+
"mul": symbolic_primitives.symbolic_mul,
29+
"div": symbolic_primitives.symbolic_div,
30+
"max": symbolic_primitives.symbolic_max,
31+
"min": symbolic_primitives.symbolic_min,
32+
"and": symbolic_primitives.symbolic_and,
33+
"or": symbolic_primitives.symbolic_or,
34+
"xor": symbolic_primitives.symbolic_xor,
35+
"not": symbolic_primitives.symbolic_not,
36+
"reduce_and": symbolic_primitives.symbolic_reduce_and,
37+
"reduce_or": symbolic_primitives.symbolic_reduce_or,
38+
"reduce_sum": symbolic_primitives.symbolic_reduce_sum,
39+
"select_n": symbolic_primitives.symbolic_select_n,
3240
}[prim.name](*args, **params)
3341
return symbolic_outvals
3442

43+
3544
def scope_put_variable(self, col: str, name: str, value: Any):
3645
variables = self._collection(col)
3746

3847
def put(target, key, val):
39-
if (key in target and isinstance(target[key], dict) and
40-
isinstance(val, Mapping)):
48+
if key in target and isinstance(target[key], dict) and isinstance(val, Mapping):
4149
for k, v in val.items():
4250
put(target[key], k, v)
4351
else:
@@ -50,11 +58,16 @@ def put_variable(self, col: str, name: str, value: Any):
5058
self.scope._variables = self.scope.variables().unfreeze()
5159
scope_put_variable(self.scope, col, name, value)
5260

61+
5362
def make_symbolic_flax_jaxpr(flax_layer, x):
5463
actual_weights = flax_layer.get_variable("params", "weights")
5564
# Convert actual weights to dummy numeric weights (if needed)
56-
if isinstance(actual_weights, list) or (isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object):
57-
numeric_weights = symbolic_primitives.map_at_elements(actual_weights, lambda x: 0)
65+
if isinstance(actual_weights, list) or (
66+
isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object
67+
):
68+
numeric_weights = symbolic_primitives.map_at_elements(
69+
actual_weights, lambda x: 0
70+
)
5871
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
5972
put_variable(flax_layer, "params", "weights", numeric_weights)
6073
# Convert input to dummy numeric input (if needed)
@@ -130,33 +143,40 @@ def eval_jaxpr_impl(jaxpr):
130143
symbolic_invals = safe_map(symbolic_read, eqn.invars)
131144
prim = eqn.primitive
132145
if type(prim) is jax.core.CallPrimitive:
133-
call_jaxpr = eqn.params['call_jaxpr']
146+
call_jaxpr = eqn.params["call_jaxpr"]
134147
if not symbolic:
135148
safe_map(write, call_jaxpr.invars, map(read, eqn.invars))
136149
try:
137-
safe_map(symbolic_write, call_jaxpr.invars,
138-
map(symbolic_read, eqn.invars))
150+
safe_map(
151+
symbolic_write,
152+
call_jaxpr.invars,
153+
map(symbolic_read, eqn.invars),
154+
)
139155
except:
140156
pass
141157
eval_jaxpr_impl(call_jaxpr)
142158
if not symbolic:
143159
safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars))
144-
safe_map(symbolic_write, eqn.outvars, map(
145-
symbolic_read, call_jaxpr.outvars))
160+
safe_map(
161+
symbolic_write, eqn.outvars, map(symbolic_read, call_jaxpr.outvars)
162+
)
146163
else:
147164
if not symbolic:
148165
outvals = prim.bind(*invals, **eqn.params)
149-
symbolic_outvals = symbolic_bind(
150-
prim, *symbolic_invals, **eqn.params)
166+
symbolic_outvals = symbolic_bind(prim, *symbolic_invals, **eqn.params)
151167
# Primitives may return multiple outputs or not
152168
if not prim.multiple_results:
153169
if not symbolic:
154170
outvals = [outvals]
155171
symbolic_outvals = [symbolic_outvals]
156172
if not symbolic:
157173
# Check that the concrete and symbolic values are equal
158-
assert numpy.array_equal(
159-
numpy.array(outvals), symbolic_outvals)
174+
#print(
175+
# f"outvals: {outvals} and symbolic_outvals: {symbolic_outvals}"
176+
#)
177+
assert numpy.allclose(
178+
numpy.array(outvals), symbolic_outvals, equal_nan=True
179+
)
160180
# Write the results of the primitive into the environment
161181
if not symbolic:
162182
safe_map(write, eqn.outvars, outvals)
@@ -169,24 +189,36 @@ def eval_jaxpr_impl(jaxpr):
169189
else:
170190
return safe_map(symbolic_read, jaxpr.outvars)[0]
171191

192+
172193
# TODO: parameterise these functions by the element conversion function
173194

174195
# TODO: use union types to consolidate these functions
175196
@dispatch
176197
def make_symbolic(x: dict):
177-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
198+
return symbolic_primitives.map_at_elements(
199+
x, symbolic_primitives.to_boolean_value_string
200+
)
201+
178202

179203
@dispatch
180204
def make_symbolic(x: list):
181-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
205+
return symbolic_primitives.map_at_elements(
206+
x, symbolic_primitives.to_boolean_value_string
207+
)
208+
182209

183210
@dispatch
184211
def make_symbolic(x: numpy.ndarray):
185-
return symbolic_primitives.map_at_elements(x, symbolic_primitives.to_boolean_value_string)
212+
return symbolic_primitives.map_at_elements(
213+
x, symbolic_primitives.to_boolean_value_string
214+
)
215+
186216

187217
@dispatch
188218
def make_symbolic(x: jax.numpy.ndarray):
189-
return symbolic_primitives.map_at_elements(convert_jax_to_numpy_arrays(x), symbolic_primitives.to_boolean_value_string)
219+
return symbolic_primitives.map_at_elements(
220+
convert_jax_to_numpy_arrays(x), symbolic_primitives.to_boolean_value_string
221+
)
190222

191223

192224
@dispatch
@@ -214,25 +246,28 @@ def make_symbolic(x: flax.core.FrozenDict):
214246
x = convert_jax_to_numpy_arrays(x.unfreeze())
215247
return flax.core.FrozenDict(make_symbolic(x))
216248

249+
217250
@dispatch
218251
def make_symbolic(*args):
219252
return tuple([make_symbolic(arg) for arg in args])
220253

254+
221255
@dispatch
222256
def make_symbolic_jaxpr(func: typing.Callable, *args):
223257
return jax.make_jaxpr(lambda *args: func(*args))(*args)
224258

225259

226260
def eval_symbolic(symbolic_function, *args):
227-
if hasattr(symbolic_function, 'literals'):
228-
return eval_jaxpr(False, symbolic_function.jaxpr, symbolic_function.literals, *args)
261+
if hasattr(symbolic_function, "literals"):
262+
return eval_jaxpr(
263+
False, symbolic_function.jaxpr, symbolic_function.literals, *args
264+
)
229265
return eval_jaxpr(False, symbolic_function.jaxpr, [], *args)
230266

231267

232268
def symbolic_expression(jaxpr, *args):
233-
if hasattr(jaxpr, 'literals'):
234-
sym_expr = eval_jaxpr(True, jaxpr.jaxpr,
235-
jaxpr.literals, *args)
269+
if hasattr(jaxpr, "literals"):
270+
sym_expr = eval_jaxpr(True, jaxpr.jaxpr, jaxpr.literals, *args)
236271
else:
237272
sym_expr = eval_jaxpr(True, jaxpr.jaxpr, [], *args)
238273
return sym_expr

0 commit comments

Comments (0)