
Commit e3bcc5b

conditionally harden weights
1 parent 948cca5

8 files changed: +87 −89 lines changed

neurallogic/hard_and.py

Lines changed: 2 additions & 3 deletions

@@ -1,6 +1,5 @@
 from typing import Any
 
-import numpy
 import jax
 from flax import linen as nn
 from typing import Callable
@@ -70,7 +69,7 @@ class SoftAndLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init,
+        weights = self.param('bit_weights', self.weights_init,
                             weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_and_layer(weights, x)
@@ -90,7 +89,7 @@ class HardAndLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'weights', nn.initializers.constant(0.0), weights_shape)
+            'bit_weights', nn.initializers.constant(0.0), weights_shape)
         return hard_and_layer(weights, x)
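The rename is load-bearing: harden() (see neurallogic/harden.py below) now keys off the parameter name, so the parameter must appear under "bit_weights" in the Flax params tree. A minimal sketch of where it lands, assuming SoftAndLayer can be constructed from layer_size alone; the sizes below are hypothetical:

# Sketch (hypothetical sizes): after this commit the AND layer's parameter
# is registered under "bit_weights" rather than "weights".
import jax
from neurallogic import hard_and

layer = hard_and.SoftAndLayer(layer_size=4)
variables = layer.init(jax.random.PRNGKey(0), jax.numpy.ones((2,)))
assert "bit_weights" in variables["params"]
assert variables["params"]["bit_weights"].shape == (4, 2)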
neurallogic/hard_not.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ class SoftNotLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param("weights", self.weights_init, weights_shape, self.dtype)
+        weights = self.param("bit_weights", self.weights_init, weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_not_layer(weights, x)
@@ -52,7 +52,7 @@ class HardNotLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param("weights", nn.initializers.constant(0.0), weights_shape)
+        weights = self.param("bit_weights", nn.initializers.constant(0.0), weights_shape)
         return hard_not_layer(weights, x)

neurallogic/hard_or.py

Lines changed: 2 additions & 2 deletions

@@ -56,7 +56,7 @@ class SoftOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', self.weights_init, weights_shape, self.dtype)
+        weights = self.param('bit_weights', self.weights_init, weights_shape, self.dtype)
         x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)
@@ -66,7 +66,7 @@ class HardOrLayer(nn.Module):
     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('weights', nn.initializers.constant(0.0), weights_shape)
+        weights = self.param('bit_weights', nn.initializers.constant(0.0), weights_shape)
         return hard_or_layer(weights, x)
 
 class SymbolicOrLayer:

neurallogic/harden.py

Lines changed: 10 additions & 4 deletions

@@ -36,14 +36,20 @@ def harden(x: jax.numpy.ndarray):
 
 @dispatch
 def harden(x: dict):
-    return symbolic_primitives.map_at_elements(x, harden)
+    # Only harden parameters that explicitly represent bits
+    def conditional_harden(k, v):
+        if k.startswith("bit_"):
+            return symbolic_primitives.map_at_elements(v, harden)
+        elif isinstance(v, dict) or isinstance(v, flax.core.FrozenDict) or isinstance(v, list):
+            return harden(v)
+        return v
+
+    return {k: conditional_harden(k, v) for k, v in x.items()}
 
 
 @dispatch
 def harden(x: flax.core.FrozenDict):
-    return flax.core.FrozenDict(
-        symbolic_primitives.map_at_elements(x.unfreeze(), harden)
-    )
+    return harden(x.unfreeze())
 
 
 @dispatch
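The point of the new dict dispatch is that non-bit parameters survive hardening unchanged. A minimal sketch of the intended behaviour; the key names and values are hypothetical, and it assumes the scalar harden overload thresholds at 0.5:

# Sketch (hypothetical keys/values): only "bit_"-prefixed entries are
# hardened; nested containers are recursed into; everything else is
# passed through unchanged.
from neurallogic import harden

params = {
    "SoftAndLayer_0": {
        "bit_weights": [0.9, 0.2],  # hardened: key starts with "bit_"
        "temperature": 0.7,         # untouched: no "bit_" prefix
    }
}
hard_params = harden.harden(params)
# Assuming scalar harden thresholds at 0.5:
#   hard_params["SoftAndLayer_0"]["bit_weights"] -> [True, False]
#   hard_params["SoftAndLayer_0"]["temperature"] -> 0.7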

neurallogic/real_encoder.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ def __call__(self, x):
         jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
             self.hard_real_encoder_layer, x
         )
+        print(f'SymbolicRealEncoderLayer: jaxpr:\n{jaxpr}')
         return symbolic_generation.symbolic_expression(jaxpr, x)

neurallogic/symbolic_generation.py

Lines changed: 4 additions & 4 deletions

@@ -60,7 +60,7 @@ def put_variable(self, col: str, name: str, value: Any):
 
 
 def make_symbolic_flax_jaxpr(flax_layer, x):
-    actual_weights = flax_layer.get_variable("params", "weights")
+    actual_weights = flax_layer.get_variable("params", "bit_weights")
     # Convert actual weights to dummy numeric weights (if needed)
     if isinstance(actual_weights, list) or (
         isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object
@@ -69,7 +69,7 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
             actual_weights, lambda x: 0
         )
         numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
-        put_variable(flax_layer, "params", "weights", numeric_weights)
+        put_variable(flax_layer, "params", "bit_weights", numeric_weights)
     # Convert input to dummy numeric input (if needed)
     if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
         x = symbolic_primitives.map_at_elements(x, lambda x: 0)
@@ -171,9 +171,9 @@ def eval_jaxpr_impl(jaxpr):
         symbolic_outvals = [symbolic_outvals]
     if not symbolic:
         # Check that the concrete and symbolic values are equal
-        #print(
+        # print(
         #     f"outvals: {outvals} and symbolic_outvals: {symbolic_outvals}"
-        #)
+        # )
         assert numpy.allclose(
             numpy.array(outvals), symbolic_outvals, equal_nan=True
         )
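For context on the two renamed calls: between get_variable and put_variable, the function swaps symbolic (object-dtype) weights for numeric dummies so JAX can trace the layer into a jaxpr. A rough sketch of that substitution, using numpy.vectorize in place of the repo's symbolic_primitives.map_at_elements, with hypothetical expression strings:

# Sketch: an object-dtype array of symbolic expressions is replaced
# elementwise with 0, giving a same-shaped int32 array that JAX can trace.
import numpy

symbolic_weights = numpy.array([["w1", "w2"], ["not(w3)", "w4"]], dtype=object)
numeric_weights = numpy.vectorize(lambda _: 0)(symbolic_weights)
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
assert numeric_weights.shape == symbolic_weights.shape  # (2, 2)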
