6
6
from plum import dispatch
7
7
8
8
9
- # TODO: remove me?
10
- def convert_element_type (x , dtype ):
11
- return x
12
- if dtype == numpy .int32 or dtype == numpy .int64 :
13
- #dtype = 'int'
14
- return x
15
- elif dtype == bool :
16
- #dtype = 'bool'
17
- #dtype = 'numpy.all'
18
- # Don't force to bool because sometimes x is a tensor that is used in a numpy.where clause
19
- return x
20
- elif dtype == numpy .float32 :
21
- #dtype = 'float'
22
- return x
23
- else :
24
- raise NotImplementedError (
25
- f'Symbolic conversion of type { type (x )} to { dtype } not implemented'
26
- )
27
-
28
- def convert (x ):
29
- return f'{ dtype } ({ x } )'
30
-
31
- return map_at_elements (x , convert )
32
-
33
-
34
9
# TODO: allow func callable to control the type of the numpy.array or jax.numpy.array
35
10
36
11
# map_at_elements should alter the elements but not the type of the container
@@ -88,49 +63,6 @@ def map_at_elements(x: tuple, func: typing.Callable):
88
63
return tuple (map_at_elements (list (x ), func ))
89
64
90
65
91
- '''
92
- @dispatch
93
- def to_boolean_value_string(x: bool):
94
- return 'True' if x else 'False'
95
-
96
-
97
- @dispatch
98
- def to_boolean_value_string(x: numpy.bool_):
99
- return 'True' if x else 'False'
100
-
101
-
102
- @dispatch
103
- def to_boolean_value_string(x: int):
104
- return 'True' if x >= 1 else 'False'
105
-
106
-
107
- @dispatch
108
- def to_boolean_value_string(x: float):
109
- return 'True' if x >= 1.0 else 'False'
110
-
111
-
112
- @dispatch
113
- def to_boolean_value_string(x: str):
114
- if x == '1' or x == '1.0' or x == 'True':
115
- return 'True'
116
- elif x == '0' or x == '0.0' or x == 'False':
117
- return 'False'
118
- else:
119
- return x
120
-
121
- @dispatch
122
- def to_numeric_value(x):
123
- if x == 'True' or x:
124
- return 1
125
- elif x == 'False' or not x:
126
- return 0
127
- elif isinstance(x, int) or isinstance(x, float):
128
- return x
129
- else:
130
- return 0
131
- '''
132
-
133
- # TODO: add tests, and handle more cases
134
66
@dispatch
135
67
def symbolic_representation (x : numpy .ndarray ):
136
68
return repr (x ).replace ('array' , 'numpy.array' ).replace ('\n ' , '' ).replace ('float32' , 'numpy.float32' ).replace ('\' ' , '' )
@@ -146,47 +78,69 @@ def symbolic_operator(operator: str, x: str) -> str:
146
78
return f'{ operator } ({ x } )' .replace ('\' ' , '' )
147
79
148
80
149
- # XXX
150
81
@dispatch
151
82
def symbolic_operator (operator : str , x : float , y : str ):
152
83
return symbolic_operator (operator , str (x ), y )
84
+
85
+
153
86
@dispatch
154
87
def symbolic_operator (operator : str , x : str , y : float ):
155
88
return symbolic_operator (operator , x , str (y ))
89
+
90
+
156
91
@dispatch
157
92
def symbolic_operator (operator : str , x : float , y : numpy .ndarray ):
158
93
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
94
+
95
+
159
96
@dispatch
160
97
def symbolic_operator (operator : str , x : numpy .ndarray , y : numpy .ndarray ):
161
98
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
99
+
100
+
162
101
@dispatch
163
102
def symbolic_operator (operator : str , x : str , y : str ):
164
103
return f'{ operator } ({ x } , { y } )' .replace ('\' ' , '' )
104
+
105
+
165
106
@dispatch
166
107
def symbolic_operator (operator : str , x : numpy .ndarray , y : float ):
167
108
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
109
+
110
+
168
111
@dispatch
169
112
def symbolic_operator (operator : str , x : list , y : float ):
170
113
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
114
+
115
+
171
116
@dispatch
172
117
def symbolic_operator (operator : str , x : list , y : list ):
173
118
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
119
+
120
+
174
121
@dispatch
175
122
def symbolic_operator (operator : str , x : bool , y : str ):
176
123
return symbolic_operator (operator , str (x ), y )
124
+
125
+
177
126
@dispatch
178
127
def symbolic_operator (operator : str , x : str , y : numpy .ndarray ):
179
128
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
129
+
130
+
180
131
@dispatch
181
132
def symbolic_operator (operator : str , x : str , y : int ):
182
133
return symbolic_operator (operator , x , str (y ))
134
+
135
+
183
136
@dispatch
184
137
def symbolic_operator (operator : str , x : list , y : numpy .ndarray ):
185
138
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
139
+
140
+
186
141
@dispatch
187
142
def symbolic_operator (operator : str , x : numpy .ndarray , y : jax .numpy .ndarray ):
188
143
return numpy .vectorize (symbolic_operator , otypes = [object ])(operator , x , y )
189
- # XXX
190
144
191
145
192
146
@dispatch
@@ -199,59 +153,6 @@ def symbolic_operator(operator: str, x: list):
199
153
return symbolic_operator (operator , numpy .array (x ))
200
154
201
155
202
- # TODO: remove infix_operator?
203
- """
204
- @dispatch
205
- def symbolic_infix_operator(operator: str, a: str, b: str) -> str:
206
- return f'{a} {operator} {b}'.replace('\' ', '')
207
-
208
-
209
- @dispatch
210
- def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: numpy.ndarray):
211
- return numpy.vectorize(symbolic_infix_operator, otypes=[object])(operator, a, b)
212
-
213
-
214
- @dispatch
215
- def symbolic_infix_operator(operator: str, a: list, b: numpy.ndarray):
216
- return symbolic_infix_operator(operator, numpy.array(a), b)
217
-
218
-
219
- @dispatch
220
- def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: list):
221
- return symbolic_infix_operator(operator, a, numpy.array(b))
222
-
223
-
224
- @dispatch
225
- def symbolic_infix_operator(operator: str, a: str, b: int):
226
- return symbolic_infix_operator(operator, a, str(b))
227
-
228
-
229
- @dispatch
230
- def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: float):
231
- return symbolic_infix_operator(operator, a, str(b))
232
-
233
-
234
- @dispatch
235
- def symbolic_infix_operator(operator: str, a: str, b: float):
236
- return symbolic_infix_operator(operator, a, str(b))
237
-
238
-
239
- @dispatch
240
- def symbolic_infix_operator(operator: str, a: numpy.ndarray, b: jax.numpy.ndarray):
241
- return symbolic_infix_operator(operator, a, numpy.array(b))
242
-
243
-
244
- # XXXX
245
- @dispatch
246
- def symbolic_infix_operator(operator: str, a: list, b: list):
247
- return symbolic_infix_operator(operator, numpy.array(a), numpy.array(b))
248
- # XXXX
249
-
250
- @dispatch
251
- def symbolic_infix_operator(operator: str, a: bool, b: str):
252
- return symbolic_infix_operator(operator, str(a), b)
253
- """
254
-
255
156
def all_concrete_values (data ):
256
157
if isinstance (data , str ):
257
158
return False
@@ -308,7 +209,6 @@ def symbolic_gt(*args, **kwargs):
308
209
return symbolic_operator ('lax_reference.gt' , * args , ** kwargs )
309
210
310
211
311
-
312
212
def symbolic_abs (* args , ** kwargs ):
313
213
if all_concrete_values ([* args ]):
314
214
return lax_reference .abs (* args , ** kwargs )
@@ -359,7 +259,6 @@ def symbolic_min(*args, **kwargs):
359
259
return symbolic_operator ('numpy.minimum' , * args , ** kwargs )
360
260
361
261
362
-
363
262
def symbolic_select_n (* args , ** kwargs ):
364
263
'''
365
264
Important comment from lax.py
@@ -375,14 +274,9 @@ def symbolic_select_n(*args, **kwargs):
375
274
else :
376
275
# swap order of on_true and on_false
377
276
# TODO: need a more general solution to unquoting symbolic strings
378
- #return f'numpy.where({repr(pred)}, {repr(on_false)}, {repr(on_true)})'.replace('\'', '')
379
- #return f'numpy.where({repr(pred)}, {repr(on_false)}, {repr(on_true)})'
380
277
evaluable_pred = symbolic_representation (pred )
381
- #print(f'evaluable_pred: {evaluable_pred}')
382
278
evaluable_on_true = symbolic_representation (on_true )
383
- #print(f'evaluable_on_true: {evaluable_on_true}')
384
279
evaluable_on_false = symbolic_representation (on_false )
385
- #print(f'evaluable_on_false: {evaluable_on_false}')
386
280
return f'lax_reference.select({ evaluable_pred } , { evaluable_on_false } , { evaluable_on_true } )'
387
281
388
282
@@ -391,7 +285,6 @@ def symbolic_and(*args, **kwargs):
391
285
return numpy .logical_and (* args , ** kwargs )
392
286
else :
393
287
return symbolic_operator ('numpy.logical_and' , * args , ** kwargs )
394
-
395
288
396
289
397
290
def symbolic_or (* args , ** kwargs ):
@@ -417,7 +310,8 @@ def symbolic_sum(*args, **kwargs):
417
310
418
311
def symbolic_broadcast_in_dim (* args , ** kwargs ):
419
312
# broadcast_in_dim requires numpy arrays not lists
420
- args = tuple ([numpy .array (arg ) if isinstance (arg , list ) else arg for arg in args ])
313
+ args = tuple ([numpy .array (arg ) if isinstance (
314
+ arg , list ) else arg for arg in args ])
421
315
return lax_reference .broadcast_in_dim (* args , ** kwargs )
422
316
423
317
@@ -435,7 +329,9 @@ def symbolic_convert_element_type(*args, **kwargs):
435
329
# If so, we can use the lax reference implementation
436
330
return lax_reference .convert_element_type (* args , dtype = kwargs ['new_dtype' ])
437
331
else :
438
- # Otherwise, we use the symbolic implementation
332
+ # Otherwise, we nop
333
+ def convert_element_type (x , dtype ):
334
+ return x
439
335
return convert_element_type (* args , dtype = kwargs ['new_dtype' ])
440
336
441
337
0 commit comments