Skip to content

Commit 101d5cd

Browse files
committed
fix my hack after jax upgrade
1 parent 653afa8 commit 101d5cd

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

neurallogic/symbolic_generation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,11 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
9797
x = numpy.asarray(x, dtype=numpy.int32)
9898
# Make the jaxpr that corresponds to the flax layer
9999
jaxpr = make_symbolic_jaxpr(flax_layer, x)
100-
# Make a list of bit_weights and thresholds but only include each if they are not None
101-
bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None]
102-
# Replace the dummy numeric weights with the actual weights in the jaxpr
103-
jaxpr.consts = bit_weights_and_thresholds
100+
if hasattr(jaxpr, '_consts'):
101+
# Make a list of bit_weights and thresholds but only include each if they are not None
102+
bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None]
103+
# Replace the dummy numeric weights with the actual weights in the jaxpr
104+
jaxpr.__setattr__('_consts', bit_weights_and_thresholds)
104105
return jaxpr
105106

106107

0 commit comments

Comments
 (0)