@@ -59,8 +59,11 @@ def put_variable(self, col: str, name: str, value: Any):
59
59
scope_put_variable (self .scope , col , name , value )
60
60
61
61
62
- def make_symbolic_flax_jaxpr (flax_layer , x ):
63
- actual_weights = flax_layer .get_variable ("params" , "bit_weights" )
62
+ # TODO: make this robust and general over multiple types of param names
63
+
64
+
65
+ def convert_to_numeric_params (flax_layer , param_names : str ):
66
+ actual_weights = flax_layer .get_variable ("params" , param_names )
64
67
# Convert actual weights to dummy numeric weights (if needed)
65
68
if isinstance (actual_weights , list ) or (
66
69
isinstance (actual_weights , numpy .ndarray ) and actual_weights .dtype == object
@@ -69,15 +72,23 @@ def make_symbolic_flax_jaxpr(flax_layer, x):
69
72
actual_weights , lambda x : 0
70
73
)
71
74
numeric_weights = numpy .asarray (numeric_weights , dtype = numpy .int32 )
72
- put_variable (flax_layer , "params" , "bit_weights" , numeric_weights )
75
+ put_variable (flax_layer , "params" , param_names , numeric_weights )
76
+ return flax_layer , actual_weights
77
+
78
+
79
+ def make_symbolic_flax_jaxpr (flax_layer , x ):
80
+ flax_layer , bit_weights = convert_to_numeric_params (flax_layer , "bit_weights" )
81
+ flax_layer , thresholds = convert_to_numeric_params (flax_layer , "thresholds" )
73
82
# Convert input to dummy numeric input (if needed)
74
83
if isinstance (x , list ) or (isinstance (x , numpy .ndarray ) and x .dtype == object ):
75
84
x = symbolic_primitives .map_at_elements (x , lambda x : 0 )
76
85
x = numpy .asarray (x , dtype = numpy .int32 )
77
86
# Make the jaxpr that corresponds to the flax layer
78
87
jaxpr = make_symbolic_jaxpr (flax_layer , x )
88
+ # Make a list of bit_weights and thresholds but only include each if they are not None
89
+ bit_weights_and_thresholds = [x for x in [bit_weights , thresholds ] if x is not None ]
79
90
# Replace the dummy numeric weights with the actual weights in the jaxpr
80
- jaxpr .consts = [ actual_weights ]
91
+ jaxpr .consts = bit_weights_and_thresholds
81
92
return jaxpr
82
93
83
94
0 commit comments