Skip to content

Commit 853a291

Browse files
committed
temporary handling of thresholds
1 parent e3bcc5b commit 853a291

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

neurallogic/real_encoder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def __call__(self, x):
6666
jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
6767
self.hard_real_encoder_layer, x
6868
)
69-
print(f'SymbolicRealEncoderLayer: jaxpr:\n{jaxpr}')
7069
return symbolic_generation.symbolic_expression(jaxpr, x)
7170

7271

neurallogic/symbolic_generation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,11 @@ def put_variable(self, col: str, name: str, value: Any):
5959
scope_put_variable(self.scope, col, name, value)
6060

6161

62-
def make_symbolic_flax_jaxpr(flax_layer, x):
63-
actual_weights = flax_layer.get_variable("params", "bit_weights")
62+
# TODO: make this robust and general over multiple types of param names
63+
64+
65+
def convert_to_numeric_params(flax_layer, param_names: str):
66+
actual_weights = flax_layer.get_variable("params", param_names)
6467
# Convert actual weights to dummy numeric weights (if needed)
6568
if isinstance(actual_weights, list) or (
6669
isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object
@@ -69,15 +72,23 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
6972
actual_weights, lambda x: 0
7073
)
7174
numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32)
72-
put_variable(flax_layer, "params", "bit_weights", numeric_weights)
75+
put_variable(flax_layer, "params", param_names, numeric_weights)
76+
return flax_layer, actual_weights
77+
78+
79+
def make_symbolic_flax_jaxpr(flax_layer, x):
80+
flax_layer, bit_weights = convert_to_numeric_params(flax_layer, "bit_weights")
81+
flax_layer, thresholds = convert_to_numeric_params(flax_layer, "thresholds")
7382
# Convert input to dummy numeric input (if needed)
7483
if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object):
7584
x = symbolic_primitives.map_at_elements(x, lambda x: 0)
7685
x = numpy.asarray(x, dtype=numpy.int32)
7786
# Make the jaxpr that corresponds to the flax layer
7887
jaxpr = make_symbolic_jaxpr(flax_layer, x)
88+
# Make a list of bit_weights and thresholds but only include each if they are not None
89+
bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None]
7990
# Replace the dummy numeric weights with the actual weights in the jaxpr
80-
jaxpr.consts = [actual_weights]
91+
jaxpr.consts = bit_weights_and_thresholds
8192
return jaxpr
8293

8394

0 commit comments

Comments
 (0)