
Commit 4ba18f2

Merge pull request #55 from github/mnist-2
Additional layer types
2 parents f247891 + f4ee2f8 commit 4ba18f2

19 files changed: +758 −175 lines

neurallogic/hard_and.py

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ class HardAndLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(0.0), weights_shape)
+            'bit_weights', nn.initializers.constant(True), weights_shape)
         return hard_and_layer(weights, x)


@@ -105,5 +105,5 @@ def __call__(self, x):

 and_layer = neural_logic_net.select(
     lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
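
The functional change in this file is the default initializer for hard bit weights, which moves from constant(0.0) to constant(True). A minimal sketch of what that initializer produces, using plain flax/jax with no neurallogic import (the shape (2, 3) is just an illustrative layer_size by input width):

    import jax
    from flax import linen as nn

    init = nn.initializers.constant(True)      # the new default for hard layers
    key = jax.random.PRNGKey(0)
    print(init(key, (2, 3)))                   # ones in the default float dtype
    print(init(key, (2, 3), jax.numpy.bool_))  # all-True boolean weights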

neurallogic/hard_majority.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import jax
from flax import linen as nn

from neurallogic import neural_logic_net, symbolic_generation


def majority_index(input_size: int) -> int:
    return (input_size - 1) // 2


def soft_majority(x: jax.numpy.array) -> float:
    index = majority_index(x.shape[-1])
    sorted_x = jax.numpy.sort(x, axis=-1)
    return jax.numpy.take(sorted_x, index, axis=-1)


def hard_majority(x: jax.numpy.array) -> bool:
    threshold = x.shape[-1] - majority_index(x.shape[-1])
    return jax.numpy.sum(x, axis=-1) >= threshold


soft_majority_layer = jax.vmap(soft_majority, in_axes=0)

hard_majority_layer = jax.vmap(hard_majority, in_axes=0)


class SoftMajorityLayer(nn.Module):
    """
    A soft-bit MAJORITY layer that transforms its inputs along the last dimension.

    Attributes:
        layer_size: The number of neurons in the layer.
        weights_init: The initializer function for the weight matrix.
    """
    @nn.compact
    def __call__(self, x):
        return soft_majority_layer(x)


class HardMajorityLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        return hard_majority_layer(x)


class SymbolicMajorityLayer:
    def __init__(self):
        self.hard_majority_layer = HardMajorityLayer()

    def __call__(self, x):
        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
            self.hard_majority_layer, x)
        return symbolic_generation.symbolic_expression(jaxpr, x)


majority_layer = neural_logic_net.select(
    lambda: SoftMajorityLayer(),
    lambda: HardMajorityLayer(),
    lambda: SymbolicMajorityLayer()
)
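
For reference, a small worked example of the two new primitives (the arithmetic is restated from the file above and runs with jax alone, no neurallogic import needed): with five inputs, majority_index is 2, so soft_majority returns the middle element of the sorted vector and hard_majority checks whether at least 3 bits are set.

    import jax.numpy as jnp

    x = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0])  # 3 of 5 bits set
    idx = (x.shape[-1] - 1) // 2              # majority_index(5) == 2
    print(jnp.sort(x)[idx])                   # soft_majority: median of the bits -> 1.0
    print(jnp.sum(x) >= x.shape[-1] - idx)    # hard_majority: 3 >= 3 -> True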

neurallogic/hard_not.py

Lines changed: 5 additions & 11 deletions
@@ -19,7 +19,6 @@ def soft_not(w: float, x: float) -> float:


 def hard_not(w: bool, x: bool) -> bool:
-    # ~(x ^ w)
     return jax.numpy.logical_not(jax.numpy.logical_xor(x, w))


@@ -48,11 +47,12 @@ def __call__(self, x):

 class HardNotLayer(nn.Module):
     layer_size: int
+    weights_init: Callable = nn.initializers.constant(True)

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param("bit_weights", nn.initializers.constant(0.0), weights_shape)
+        weights = self.param("bit_weights", self.weights_init, weights_shape)
         return hard_not_layer(weights, x)


@@ -67,13 +67,7 @@ def __call__(self, x):


 not_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: HardNotLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.uniform(
-        1.0
-    ), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: HardNotLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(1.0), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size),
 )
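
A note on the new default: hard_not computes NOT(x XOR w), so a weight of True leaves its input unchanged and a weight of False negates it, which is why initializing bit_weights with constant(True) starts the hard layer as a pass-through. A quick standalone check (function body copied from the file above):

    import jax.numpy as jnp

    def hard_not(w, x):
        return jnp.logical_not(jnp.logical_xor(x, w))

    for w in (True, False):
        for x in (True, False):
            print(w, x, bool(hard_not(w, x)))
    # w=True:  identity (True -> True, False -> False)
    # w=False: negation (True -> False, False -> True)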

neurallogic/hard_or.py

Lines changed: 20 additions & 9 deletions
@@ -1,4 +1,3 @@
-from functools import reduce
 from typing import Callable

 import jax
@@ -12,7 +11,7 @@ def soft_or_include(w: float, x: float) -> float:
     w > 0.5 implies the and operation is active, else inactive

     Assumes x is in [0, 1]
-
+
     Corresponding hard logic: b AND w
     """
     w = jax.numpy.clip(w, 0.0, 1.0)
@@ -27,6 +26,7 @@ def soft_or_neuron(w, x):
     x = jax.vmap(soft_or_include, 0, 0)(w, x)
     return jax.numpy.max(x)

+
 def hard_or_neuron(w, x):
     x = jax.vmap(hard_or_include, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])
@@ -37,6 +37,8 @@ def hard_or_neuron(w, x):
 hard_or_layer = jax.vmap(hard_or_neuron, (0, None), 0)

 # TODO: investigate better initialization
+
+
 def initialize_near_to_one():
     def init(key, shape, dtype):
         dtype = jax.dtypes.canonicalize_dtype(dtype)
@@ -48,6 +50,7 @@ def init(key, shape, dtype):
         return x
     return init

+
 class SoftOrLayer(nn.Module):
     layer_size: int
     weights_init: Callable = initialize_near_to_one()
@@ -56,29 +59,37 @@ class SoftOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)
+        weights = self.param(
+            'bit_weights', self.weights_init, weights_shape, self.dtype)
+        x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)

+
 class HardOrLayer(nn.Module):
     layer_size: int

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param(
+            'bit_weights', nn.initializers.constant(True), weights_shape)
         return hard_or_layer(weights, x)

+
 class SymbolicOrLayer:
     def __init__(self, layer_size):
         self.layer_size = layer_size
         self.hard_or_layer = HardOrLayer(self.layer_size)

     def __call__(self, x):
-        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
+            self.hard_or_layer, x)
         return symbolic_generation.symbolic_expression(jaxpr, x)

+
 or_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: HardOrLayer(layer_size),
-    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_one(
+    ), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True), dtype=jax.numpy.float32: HardOrLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size))
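
For orientation, a sketch of the hard OR path with the new all-True initialization. hard_or_neuron is copied from the context above; the body of hard_or_include is not shown in this diff, so the logical_and form below is an assumption modelled on hard_xor_include.

    import jax
    import jax.numpy as jnp

    def hard_or_include(w, x):
        return jnp.logical_and(x, w)  # assumed form, mirrors hard_xor_include

    def hard_or_neuron(w, x):
        x = jax.vmap(hard_or_include, 0, 0)(w, x)
        return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])

    w = jnp.array([True, True, True])  # the constant(True) default after this commit
    print(hard_or_neuron(w, jnp.array([False, True, False])))   # True
    print(hard_or_neuron(w, jnp.array([False, False, False])))  # False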

neurallogic/hard_xor.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
from typing import Callable

import jax
from flax import linen as nn

from neurallogic import neural_logic_net, symbolic_generation


def soft_xor_include(w: float, x: float) -> float:
    """
    w > 0.5 implies the and operation is active, else inactive

    Assumes x is in [0, 1]

    Corresponding hard logic: b AND w
    """
    w = jax.numpy.clip(w, 0.0, 1.0)
    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)


def hard_xor_include(w, x):
    return jax.numpy.logical_and(x, w)


def soft_xor_neuron(w, x):
    # Conditionally include input bits, according to weights
    x = jax.vmap(soft_xor_include, 0, 0)(w, x)

    def xor(x, y):
        return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
    x = jax.lax.reduce(x, jax.numpy.array(0, dtype=x.dtype), xor, (0,))
    return x


def hard_xor_neuron(w, x):
    x = jax.vmap(hard_xor_include, 0, 0)(w, x)
    return jax.lax.reduce(x, False, jax.lax.bitwise_xor, [0])


soft_xor_layer = jax.vmap(soft_xor_neuron, (0, None), 0)


hard_xor_layer = jax.vmap(hard_xor_neuron, (0, None), 0)


class SoftXorLayer(nn.Module):
    layer_size: int
    weights_init: Callable = nn.initializers.uniform(1.0)
    dtype: jax.numpy.dtype = jax.numpy.float32

    @nn.compact
    def __call__(self, x):
        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
        weights = self.param(
            'bit_weights', self.weights_init, weights_shape, self.dtype)
        x = jax.numpy.asarray(x, self.dtype)
        return soft_xor_layer(weights, x)


class HardXorLayer(nn.Module):
    layer_size: int

    @nn.compact
    def __call__(self, x):
        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
        weights = self.param(
            'bit_weights', nn.initializers.constant(True), weights_shape)
        return hard_xor_layer(weights, x)


class SymbolicXorLayer:
    def __init__(self, layer_size):
        self.layer_size = layer_size
        self.hard_xor_layer = HardXorLayer(self.layer_size)

    def __call__(self, x):
        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
            self.hard_xor_layer, x)
        return symbolic_generation.symbolic_expression(jaxpr, x)


xor_layer = neural_logic_net.select(
    lambda layer_size, weights_init=nn.initializers.uniform(
        1.0), dtype=jax.numpy.float32: SoftXorLayer(layer_size, weights_init, dtype),
    lambda layer_size, weights_init=nn.initializers.constant(
        True), dtype=jax.numpy.float32: HardXorLayer(layer_size),
    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicXorLayer(layer_size))
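
The soft XOR reduction folds the included bits with xor(x, y) = min(max(x, y), 1 - min(x, y)), which matches boolean XOR at the {0, 1} corners and interpolates in between. A standalone check, with the pairwise op restated from soft_xor_neuron above:

    import jax.numpy as jnp

    def soft_xor(x, y):
        return jnp.minimum(jnp.maximum(x, y), 1.0 - jnp.minimum(x, y))

    for x in (0.0, 1.0):
        for y in (0.0, 1.0):
            print(x, y, float(soft_xor(x, y)))  # 0,0 -> 0   0,1 -> 1   1,0 -> 1   1,1 -> 0
    print(float(soft_xor(0.9, 0.8)))            # interior value: min(0.9, 0.2) = 0.2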

neurallogic/harden.py

Lines changed: 9 additions & 7 deletions
@@ -20,18 +20,27 @@ def harden(x: float):
     return harden_float(x)


+@dispatch
+def harden(x: bool):
+    return x
+
+
 @dispatch
 def harden(x: list):
     return map_at_elements.map_at_elements(x, harden_float)


 @dispatch
 def harden(x: numpy.ndarray):
+    if x.ndim == 0:
+        return harden(x.item())
     return harden_array(x)


 @dispatch
 def harden(x: jax.numpy.ndarray):
+    if x.ndim == 0:
+        return harden(x.item())
     return harden_array(x)


@@ -53,13 +62,6 @@ def harden(x: flax.core.FrozenDict):
     return harden(x.unfreeze())


-@dispatch
-def harden(*args):
-    if len(args) == 1:
-        return harden(args[0])
-    return tuple([harden(arg) for arg in args])
-
-
 @dispatch
 def map_keys_nested(f, d: dict) -> dict:
     return {
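
A hedged usage sketch of the new harden cases (assumes this repository's neurallogic package is importable; the scalar result depends on harden_float, which is defined earlier in the file and not shown in this diff):

    import jax.numpy as jnp
    from neurallogic import harden

    harden.harden(True)                   # bool: returned unchanged
    harden.harden(jnp.array(0.9))         # 0-dim array: unwrapped with .item(), then hardened as a scalar
    harden.harden(jnp.array([0.2, 0.7]))  # nd array: still handled by harden_array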

neurallogic/harden_layer.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ def straight_through_harden_element(x):
 def hard_harden_layer(x):
     return x

+#TODO: can we harden arbitrary tensors?
+#TODO: is this correct?
 def symbolic_harden_layer(x):
     return x


neurallogic/map_at_elements.py

Lines changed: 5 additions & 1 deletion
@@ -30,6 +30,11 @@ def map_at_elements(x: numpy.float32, func: typing.Callable):
     return func(x)


+@dispatch
+def map_at_elements(x: numpy.int32, func: typing.Callable):
+    return func(x)
+
+
 @dispatch
 def map_at_elements(x: list, func: typing.Callable):
     return [map_at_elements(item, func) for item in x]
@@ -55,4 +60,3 @@ def map_at_elements(x: dict, func: typing.Callable):
 @dispatch
 def map_at_elements(x: tuple, func: typing.Callable):
     return tuple(map_at_elements(list(x), func))
-
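
A hedged sketch of what the new numpy.int32 overload enables (again assuming the neurallogic package is importable): integer scalars nested inside containers can now be mapped just like floats.

    import numpy
    from neurallogic import map_at_elements

    data = [numpy.int32(1), [numpy.float32(0.2), numpy.int32(0)]]
    result = map_at_elements.map_at_elements(data, lambda v: v > 0.5)
    # applies the predicate to every scalar element, integers included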

neurallogic/real_encoder.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@

 from neurallogic import neural_logic_net, symbolic_generation

+# TODO: perhaps this can be simplified with a simple multiplication?
+# TODO: implement a soft_real_decoder that can perhaps replace the port count approach

 def soft_real_encoder(t: float, x: float) -> float:
     eps = 0.0000001

0 commit comments