
Commit 1fdbb61

add experimental xor neuron
1 parent 3bcb6d9 commit 1fdbb61

File tree

11 files changed, +407 -42 lines changed


neurallogic/hard_and.py

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ class HardAndLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(0.0), weights_shape)
+            'bit_weights', nn.initializers.constant(True), weights_shape)
         return hard_and_layer(weights, x)


@@ -105,5 +105,5 @@ def __call__(self, x):

 and_layer = neural_logic_net.select(
     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
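
Note (not part of the commit): a minimal sketch of what the new default initializer produces, assuming Flax's standard constant-initializer semantics; the bool dtype in the second call is passed explicitly for illustration only.

import jax
from flax import linen as nn

key = jax.random.PRNGKey(0)
old_init = nn.initializers.constant(0.0)   # previous default: every bit weight starts inactive
new_init = nn.initializers.constant(True)  # new default: every bit weight starts active

print(old_init(key, (2, 3)))                   # 2x3 array of 0.0
print(new_init(key, (2, 3), jax.numpy.bool_))  # 2x3 array of True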

neurallogic/hard_not.py

Lines changed: 5 additions & 10 deletions
@@ -48,11 +48,12 @@ def __call__(self, x):

 class HardNotLayer(nn.Module):
     layer_size: int
+    weights_init: Callable = nn.initializers.constant(True)

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param("bit_weights", nn.initializers.constant(0.0), weights_shape)
+        weights = self.param("bit_weights", self.weights_init, weights_shape)
         return hard_not_layer(weights, x)


@@ -67,13 +68,7 @@ def __call__(self, x):


 not_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: HardNotLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
 )
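
Note (not part of the commit): a hypothetical construction sketch showing how the new weights_init field can be overridden per instance; the import path neurallogic.hard_not follows from the file location, and the bool dtype is requested only for illustration.

import jax
from flax import linen as nn
from neurallogic import hard_not

# Default: bit_weights are initialized with nn.initializers.constant(True).
default_layer = hard_not.HardNotLayer(layer_size=4)

# The field can now be overridden, e.g. to start with every bit inactive.
custom_layer = hard_not.HardNotLayer(
    layer_size=4, weights_init=nn.initializers.constant(False))

# Calling the default initializer directly shows the starting values.
w = default_layer.weights_init(jax.random.PRNGKey(0), (4, 3), jax.numpy.bool_)
print(w.shape)  # (4, 3), all True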

neurallogic/hard_or.py

Lines changed: 19 additions & 8 deletions
@@ -1,4 +1,3 @@
-from functools import reduce
 from typing import Callable

 import jax
@@ -12,7 +11,7 @@ def soft_or_include(w: float, x: float) -> float:
     w > 0.5 implies the and operation is active, else inactive

     Assumes x is in [0, 1]
-
+
     Corresponding hard logic: b AND w
     """
     w = jax.numpy.clip(w, 0.0, 1.0)
@@ -27,6 +26,7 @@ def soft_or_neuron(w, x):
     x = jax.vmap(soft_or_include, 0, 0)(w, x)
     return jax.numpy.max(x)

+
 def hard_or_neuron(w, x):
     x = jax.vmap(hard_or_include, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])
@@ -37,6 +37,8 @@ def hard_or_neuron(w, x):
 hard_or_layer = jax.vmap(hard_or_neuron, (0, None), 0)

 # TODO: investigate better initialization
+
+
 def initialize_near_to_one():
     def init(key, shape, dtype):
         dtype = jax.dtypes.canonicalize_dtype(dtype)
@@ -48,6 +50,7 @@ def init(key, shape, dtype):
         return x
     return init

+
 class SoftOrLayer(nn.Module):
     layer_size: int
     weights_init: Callable = initialize_near_to_one()
@@ -56,29 +59,37 @@ class SoftOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', self.weights_init, weights_shape, self.dtype)
+        weights = self.param(
+            'bit_weights', self.weights_init, weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)

+
 class HardOrLayer(nn.Module):
     layer_size: int

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param(
+            'bit_weights', nn.initializers.constant(True), weights_shape)
         return hard_or_layer(weights, x)

+
 class SymbolicOrLayer:
     def __init__(self, layer_size):
         self.layer_size = layer_size
         self.hard_or_layer = HardOrLayer(self.layer_size)

     def __call__(self, x):
-        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_or_layer, x)
         return symbolic_generation.symbolic_expression(jaxpr, x)

+
 or_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: HardOrLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_one(
+    ), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True), dtype=jax.numpy.float32: HardOrLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size))
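
Note (not part of the commit): a small usage sketch of the hard OR neuron and layer defined above. hard_or_include is not shown in this hunk; per the docstring ("Corresponding hard logic: b AND w") it is assumed to gate each input with a logical AND before the OR reduction, and the import path neurallogic.hard_or follows from the file location.

import jax.numpy as jnp
from neurallogic import hard_or

w = jnp.array([True, False, True])   # include only the first and third inputs
x = jnp.array([False, True, True])

# Gated inputs are [False, False, True]; OR-reducing them gives True.
print(hard_or.hard_or_neuron(w, x))

# hard_or_layer vmaps the neuron over a (layer_size, num_inputs) weight matrix,
# producing one output bit per neuron.
W = jnp.array([[True, False, True],
               [True, False, False]])
print(hard_or.hard_or_layer(W, x))   # [True, False]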

neurallogic/hard_xor.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+from typing import Callable
+
+import jax
+from flax import linen as nn
+
+from neurallogic import neural_logic_net, symbolic_generation
+
+
+def soft_xor_include(w: float, x: float) -> float:
+    """
+    w > 0.5 implies the and operation is active, else inactive
+
+    Assumes x is in [0, 1]
+
+    Corresponding hard logic: b AND w
+    """
+    w = jax.numpy.clip(w, 0.0, 1.0)
+    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
+
+
+def hard_xor_include(w, x):
+    return jax.numpy.logical_and(x, w)
+
+
+def soft_xor_neuron(w, x):
+    # Conditionally include input bits, according to weights
+    x = jax.vmap(soft_xor_include, 0, 0)(w, x)
+    # Compute the most sensitive bit
+    margins = jax.vmap(lambda x: jax.numpy.abs(0.5 - x))(x)
+    sensitive_bit_index = jax.numpy.argmin(margins)
+    sensitive_bit = jax.numpy.take(x, sensitive_bit_index)
+    # Compute the logical xor of the bits
+    hard_x = jax.vmap(lambda x: jax.numpy.where(x > 0.5, True, False))(x)
+    logical_xor = jax.lax.reduce(hard_x, False, jax.numpy.logical_xor, (0,))
+    # Compute the representative bit
+    hard_sensitive_bit = jax.numpy.where(sensitive_bit > 0.5, True, False)
+    representative_bit = jax.numpy.where(logical_xor == hard_sensitive_bit,
+                                         sensitive_bit,
+                                         1.0 - sensitive_bit
+                                         )
+    return representative_bit
+
+
+def hard_xor_neuron(w, x):
+    x = jax.vmap(hard_xor_include, 0, 0)(w, x)
+    return jax.lax.reduce(x, False, jax.lax.bitwise_xor, [0])
+
+
+soft_xor_layer = jax.vmap(soft_xor_neuron, (0, None), 0)
+
+hard_xor_layer = jax.vmap(hard_xor_neuron, (0, None), 0)
+
+
+class SoftXorLayer(nn.Module):
+    layer_size: int
+    weights_init: Callable = nn.initializers.uniform(
+        1.0)  # TODO: investigate better initialization
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param(
+            'bit_weights', self.weights_init, weights_shape, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)
+        return soft_xor_layer(weights, x)
+
+
+class HardXorLayer(nn.Module):
+    layer_size: int
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param(
+            'bit_weights', nn.initializers.constant(True), weights_shape)
+        return hard_xor_layer(weights, x)
+
+
+class SymbolicXorLayer:
+    def __init__(self, layer_size):
+        self.layer_size = layer_size
+        self.hard_xor_layer = HardXorLayer(self.layer_size)
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_xor_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
+
+xor_layer = neural_logic_net.select(
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0), dtype=jax.numpy.float32: SoftXorLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True), dtype=jax.numpy.float32: HardXorLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicXorLayer(layer_size))
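
Note (not part of the commit): a worked example of the "most sensitive bit" scheme used by soft_xor_neuron above. With all weights active, the hardened inputs [True, False, True] have parity False, so the neuron returns a soft value that hardens to False, built from the included input closest to 0.5; the import path neurallogic.hard_xor follows from the file location.

import jax.numpy as jnp
from neurallogic import hard_xor

w = jnp.array([1.0, 1.0, 1.0])   # all inputs included
x = jnp.array([0.9, 0.2, 0.7])   # hardened: True, False, True -> parity False

# The most sensitive included bit is 0.7; hardened it is True, which disagrees
# with the parity, so the neuron returns its complement 1.0 - 0.7, i.e. ~0.3,
# which hardens back to False.
print(hard_xor.soft_xor_neuron(w, x))

# The hard neuron computes the same parity directly on booleans.
print(hard_xor.hard_xor_neuron(jnp.array([True, True, True]),
                               jnp.array([True, False, True])))   # False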

neurallogic/real_encoder.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@

 from neurallogic import neural_logic_net, symbolic_generation

+# TODO: perhaps this can be simplified with a simple multiplication?
+# TODO: implement a soft_real_decoder that can perhaps replace the port count approach

 def soft_real_encoder(t: float, x: float) -> float:
     eps = 0.0000001

neurallogic/symbolic_generation.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def symbolic_bind(prim, *args, **params):
         'not': symbolic_primitives.symbolic_not,
         'reduce_and': symbolic_primitives.symbolic_reduce_and,
         'reduce_or': symbolic_primitives.symbolic_reduce_or,
+        'reduce_xor': symbolic_primitives.symbolic_reduce_xor,
         'reduce_sum': symbolic_primitives.symbolic_reduce_sum,
         'select_n': symbolic_primitives.symbolic_select_n,
     }[prim.name](*args, **params)
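
Note (not part of the commit): the new entry follows symbolic_bind's name-based dispatch, where a primitive's name selects the symbolic handler applied to its arguments. A standalone sketch of that pattern with hypothetical stub handlers, for illustration only.

# Hypothetical stubs standing in for symbolic_primitives.* handlers.
def symbolic_reduce_or_stub(*args, **params):
    return 'or(' + ', '.join(map(str, args)) + ')'

def symbolic_reduce_xor_stub(*args, **params):
    return 'xor(' + ', '.join(map(str, args)) + ')'

handlers = {
    'reduce_or': symbolic_reduce_or_stub,
    'reduce_xor': symbolic_reduce_xor_stub,  # mirrors the new registration above
}

print(handlers['reduce_xor']('a', 'b'))  # xor(a, b)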

neurallogic/symbolic_primitives.py

Lines changed: 17 additions & 0 deletions
@@ -262,6 +262,23 @@ def symbolic_reduce_or(*args, **kwargs):
         )


+def symbolic_reduce_xor(*args, **kwargs):
+    if all_concrete_values([*args]):
+        return lax_reference.reduce(
+            *args,
+            init_value=False,
+            computation=numpy.logical_xor,
+            dimensions=kwargs['axes'],
+        )
+    else:
+        return symbolic_reduce(
+            *args,
+            init_value='False',
+            computation=symbolic_xor,
+            dimensions=kwargs['axes'],
+        )
+
+
 def symbolic_reduce_sum(*args, **kwargs):
     if all_concrete_values([*args]):
         return lax_reference.reduce(
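
Note (not part of the commit): a tiny NumPy illustration of the reduction semantics the concrete branch implements: XOR-reducing a boolean vector from an initial value of False yields its parity.

import numpy

bits = numpy.array([True, False, True, True])
# logical_xor is a NumPy ufunc, so .reduce folds it across the array; since
# False is the XOR identity, this matches reducing with init_value=False.
print(numpy.logical_xor.reduce(bits))  # True: three bits set, odd parity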

tests/test_hard_or.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
 import jax
-import jax.numpy as jnp
 import numpy
 import optax
 from flax import linen as nn
@@ -134,7 +133,7 @@ def test_net(type, x):
     input = jax.numpy.array(x)
     output = jax.numpy.array(y)

-    # Train the and layer
+    # Train the or layer
     tx = optax.sgd(0.1)
     state = train_state.TrainState.create(apply_fn=jax.vmap(
         soft.apply, in_axes=(None, 0)), params=weights, tx=tx)
