File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -97,10 +97,11 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
97
97
x = numpy .asarray (x , dtype = numpy .int32 )
98
98
# Make the jaxpr that corresponds to the flax layer
99
99
jaxpr = make_symbolic_jaxpr (flax_layer , x )
100
- # Make a list of bit_weights and thresholds but only include each if they are not None
101
- bit_weights_and_thresholds = [x for x in [bit_weights , thresholds ] if x is not None ]
102
- # Replace the dummy numeric weights with the actual weights in the jaxpr
103
- jaxpr .consts = bit_weights_and_thresholds
100
+ if hasattr (jaxpr , '_consts' ):
101
+ # Make a list of bit_weights and thresholds but only include each if they are not None
102
+ bit_weights_and_thresholds = [x for x in [bit_weights , thresholds ] if x is not None ]
103
+ # Replace the dummy numeric weights with the actual weights in the jaxpr
104
+ jaxpr .__setattr__ ('_consts' , bit_weights_and_thresholds )
104
105
return jaxpr
105
106
106
107
You can’t perform that action at this time.
0 commit comments