
Commit 0f0ed86

Merge pull request #57 from github/noisy-xor
Infrastructure for experiments
2 parents 27a4d9f + 48d58e1 · commit 0f0ed86

16 files changed: +11333, -522 lines changed

neurallogic/hard_and.py

Lines changed: 24 additions & 30 deletions
@@ -1,38 +1,19 @@
-from typing import Any
+from typing import Callable
 
 import jax
 from flax import linen as nn
-from typing import Callable
-
-
-from neurallogic import neural_logic_net, symbolic_generation
-
-
-def soft_and_include(w: float, x: float) -> float:
-    """
-    w > 0.5 implies the and operation is active, else inactive
-
-    Assumes x is in [0, 1]
-
-    Corresponding hard logic: x OR ! w
-    """
-    w = jax.numpy.clip(w, 0.0, 1.0)
-    return jax.numpy.maximum(x, 1.0 - w)
-
-
-
-def hard_and_include(w, x):
-    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
 
+from neurallogic import hard_masks, neural_logic_net, symbolic_generation
 
 
+# TODO: seperate and operation from mask operation
 def soft_and_neuron(w, x):
-    x = jax.vmap(soft_and_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.soft_mask_to_true, 0, 0)(w, x)
     return jax.numpy.min(x)
 
 
 def hard_and_neuron(w, x):
-    x = jax.vmap(hard_and_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.hard_mask_to_true, 0, 0)(w, x)
     return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])
 
 
@@ -41,6 +22,7 @@ def hard_and_neuron(w, x):
 hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)
 
 
+# TODO: move initialization to separate file
 def initialize_near_to_zero():
     # TODO: investigate better initialization
     def init(key, shape, dtype):
@@ -51,6 +33,7 @@ def init(key, shape, dtype):
         x = 0.5 * x - 1
         x = jax.numpy.clip(x, 0.001, 0.999)
         return x
+
     return init
 
 
@@ -62,15 +45,17 @@ class SoftAndLayer(nn.Module):
         layer_size: The number of neurons in the layer.
         weights_init: The initializer function for the weight matrix.
     """
+
     layer_size: int
     weights_init: Callable = initialize_near_to_zero()
    dtype: jax.numpy.dtype = jax.numpy.float32
 
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', self.weights_init,
-                             weights_shape, self.dtype)
+        weights = self.param(
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
         x = jax.numpy.asarray(x, self.dtype)
         return soft_and_layer(weights, x)
 
@@ -83,13 +68,15 @@ class HardAndLayer(nn.Module):
     Attributes:
         layer_size: The number of neurons in the layer.
     """
+
     layer_size: int
 
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(True), weights_shape)
+            "bit_weights", nn.initializers.constant(True), weights_shape
+        )
         return hard_and_layer(weights, x)
 
 
@@ -104,6 +91,13 @@ def __call__(self, x):
 
 
 and_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(
+        layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size),
+)
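
For orientation, a minimal usage sketch of the refactored AND neuron (a sketch only, assuming the neurallogic package from this repository is importable; the weight and input values below are illustrative): a weight under 0.5 lifts its input towards 1.0 via hard_masks.soft_mask_to_true, so that input cannot pull the min-based conjunction down.

import jax.numpy as jnp
from neurallogic import hard_and, hard_masks

w = jnp.array([0.9, 0.9, 0.1])  # third weight < 0.5: that input is effectively excluded
x = jnp.array([1.0, 0.2, 0.0])

# soft_mask_to_true(0.1, 0.0) = max(0.0, 1 - 0.1) = 0.9: the excluded 0.0 is lifted towards 1
print(hard_masks.soft_mask_to_true(0.1, 0.0))

# soft_and_neuron masks every input with soft_mask_to_true, then takes the min:
# min(1.0, 0.2, 0.9) = 0.2
print(hard_and.soft_and_neuron(w, x))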

neurallogic/hard_dropout.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+from typing import Optional, Sequence
+
+import jax
+from flax import linen as nn
+from jax import lax, random
+
+from neurallogic import neural_logic_net
+
+
+class SoftHardDropout(nn.Module):
+    """Create a dropout layer suitable for dropping soft-bit values.
+    Adapted from flax/stochastic.py
+
+
+    Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+    to include an RNG seed named `'dropout'`. For example::
+
+        model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
+
+    Attributes:
+      rate: the dropout probability. (_not_ the keep rate!)
+      broadcast_dims: dimensions that will share the same dropout mask
+      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+        masked, whereas if true, no mask is applied and the inputs are returned
+        as is.
+      rng_collection: the rng collection name to use when requesting an rng key.
+    """
+
+    rate: float
+    broadcast_dims: Sequence[int] = ()
+    deterministic: Optional[bool] = None
+    rng_collection: str = "dropout"
+    dropout_value: float = 0.0
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        """Applies a random dropout mask to the input.
+
+        Args:
+          inputs: the inputs that should be randomly masked.
+            Masking means setting the input bits to 0.5.
+          deterministic: if false the inputs are masked,
+            whereas if true, no mask is applied and the inputs are returned
+            as is.
+
+        Returns:
+          The masked inputs
+        """
+        deterministic = nn.merge_param(
+            "deterministic", self.deterministic, deterministic
+        )
+
+        if (self.rate == 0.0) or deterministic:
+            return inputs
+
+        # Prevent gradient NaNs in 1.0 edge-case.
+        if self.rate == 1.0:
+            return jax.numpy.zeros_like(inputs)
+
+        keep_prob = 1.0 - self.rate
+        rng = self.make_rng(self.rng_collection)
+        broadcast_shape = list(inputs.shape)
+        for dim in self.broadcast_dims:
+            broadcast_shape[dim] = 1
+        mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+        mask = jax.numpy.broadcast_to(mask, inputs.shape)
+        masked_values = jax.numpy.full_like(
+            inputs, self.dropout_value, dtype=self.dtype
+        )
+        return lax.select(mask, inputs, masked_values)
+
+
+class HardHardDropout(nn.Module):
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        return inputs
+
+
+class SymbolicHardDropout(nn.Module):
+    @nn.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = None):
+        return inputs
+
+
+hard_dropout = neural_logic_net.select(
+    lambda **kwargs: SoftHardDropout(**kwargs),
+    lambda **kwargs: HardHardDropout(**kwargs),
+    lambda **kwargs: SymbolicHardDropout(**kwargs),
+)
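
A short usage sketch for the soft dropout variant (illustrative shapes, rate, and dropout_value; SoftHardDropout holds no parameters, so an empty variables dict is enough, and an RNG named 'dropout' must be supplied as the docstring notes):

import jax
import jax.numpy as jnp
from neurallogic.hard_dropout import SoftHardDropout

x = jnp.full((2, 6), 0.9)  # soft bits in [0, 1]
layer = SoftHardDropout(rate=0.5, dropout_value=0.5, deterministic=False)

# Supply the "dropout" RNG collection when applying the module.
y = layer.apply({}, x, rngs={"dropout": jax.random.PRNGKey(0)})

# Dropped positions are overwritten with dropout_value (here 0.5, an "unknown" soft bit);
# the code above selects replacement values rather than zeroing and rescaling.
print(y)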

neurallogic/hard_masks.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+from typing import Callable
+
+import jax
+from flax import linen as nn
+
+from neurallogic import neural_logic_net, symbolic_generation
+
+
+def soft_mask_to_true(w: float, x: float) -> float:
+    """
+    w > 0.5 implies the mask operation is inactive, else active
+
+    Assumes x is in [0, 1]
+
+    Corresponding hard logic: x OR ! w
+    """
+    w = jax.numpy.clip(w, 0.0, 1.0)
+    return jax.numpy.maximum(x, 1.0 - w)
+
+
+def hard_mask_to_true(w, x):
+    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
+
+
+soft_mask_to_true_neuron = jax.vmap(soft_mask_to_true, 0, 0)
+
+hard_mask_to_true_neuron = jax.vmap(hard_mask_to_true, 0, 0)
+
+
+soft_mask_to_true_layer = jax.vmap(soft_mask_to_true_neuron, (0, None), 0)
+
+hard_mask_to_true_layer = jax.vmap(hard_mask_to_true_neuron, (0, None), 0)
+
+
+def soft_mask_to_false(w: float, x: float) -> float:
+    """
+    w > 0.5 implies the mask is inactive, else active
+
+    Assumes x is in [0, 1]
+
+    Corresponding hard logic: b AND w
+    """
+    w = jax.numpy.clip(w, 0.0, 1.0)
+    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
+
+
+def hard_mask_to_false(w, x):
+    return jax.numpy.logical_and(x, w)
+
+
+soft_mask_to_false_neuron = jax.vmap(soft_mask_to_false, 0, 0)
+
+hard_mask_to_false_neuron = jax.vmap(hard_mask_to_false, 0, 0)
+
+
+soft_mask_to_false_layer = jax.vmap(soft_mask_to_false_neuron, (0, None), 0)
+
+hard_mask_to_false_layer = jax.vmap(hard_mask_to_false_neuron, (0, None), 0)
+
+
+class SoftMaskLayer(nn.Module):
+    mask_layer_operation: Callable
+    layer_size: int
+    weights_init: Callable = nn.initializers.uniform(1.0)
+    dtype: jax.numpy.dtype = jax.numpy.float32
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param(
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
+        x = jax.numpy.asarray(x, self.dtype)
+        return self.mask_layer_operation(weights, x)
+
+
+class HardMaskLayer(nn.Module):
+    mask_layer_operation: Callable
+    layer_size: int
+    weights_init: Callable = nn.initializers.constant(True)
+
+    @nn.compact
+    def __call__(self, x):
+        weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
+        weights = self.param("bit_weights", self.weights_init, weights_shape)
+        return self.mask_layer_operation(weights, x)
+
+
+class SymbolicMaskLayer:
+    def __init__(self, mask_layer):
+        self.hard_mask_layer = mask_layer
+
+    def __call__(self, x):
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_mask_layer, x)
+        return symbolic_generation.symbolic_expression(jaxpr, x)
+
+
+mask_to_true_layer = neural_logic_net.select(
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftMaskLayer(
+        soft_mask_to_true_layer, layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_true_layer, layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicMaskLayer(
+        HardMaskLayer(hard_mask_to_true_layer, layer_size)
+    ),
+)
+
+
+mask_to_false_layer = neural_logic_net.select(
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SoftMaskLayer(
+        soft_mask_to_false_layer, layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_false_layer, layer_size),
+    lambda layer_size, weights_init=nn.initializers.uniform(
+        1.0
+    ), dtype=jax.numpy.float32: SymbolicMaskLayer(
+        HardMaskLayer(hard_mask_to_false_layer, layer_size)
+    ),
+)
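
A small sanity-check sketch for the new mask primitives (a sketch, assuming the package is importable): at crisp 0/1 values the soft masks agree with their hard counterparts, i.e. soft_mask_to_true behaves like x OR NOT w and soft_mask_to_false like x AND w.

from neurallogic import hard_masks

for w in (0.0, 1.0):
    for x in (0.0, 1.0):
        soft_t = hard_masks.soft_mask_to_true(w, x)   # max(x, 1 - w)
        hard_t = hard_masks.hard_mask_to_true(bool(w), bool(x))
        soft_f = hard_masks.soft_mask_to_false(w, x)  # 1 - max(1 - x, 1 - w)
        hard_f = hard_masks.hard_mask_to_false(bool(w), bool(x))
        assert bool(soft_t) == bool(hard_t)
        assert bool(soft_f) == bool(hard_f)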
