8
8
9
9
# TODO: remove me?
10
10
def convert_element_type (x , dtype ):
11
+ return x
11
12
if dtype == numpy .int32 or dtype == numpy .int64 :
12
13
#dtype = 'int'
13
14
return x
@@ -199,6 +200,7 @@ def symbolic_operator(operator: str, x: list):
199
200
200
201
201
202
# TODO: remove infix_operator?
203
+ """
202
204
@dispatch
203
205
def symbolic_infix_operator(operator: str, a: str, b: str) -> str:
204
206
return f'{a} {operator} {b}'.replace('\' ', '')
@@ -248,7 +250,7 @@ def symbolic_infix_operator(operator: str, a: list, b: list):
248
250
@dispatch
249
251
def symbolic_infix_operator(operator: str, a: bool, b: str):
250
252
return symbolic_infix_operator(operator, str(a), b)
251
-
253
+ """
252
254
253
255
def all_concrete_values (data ):
254
256
if isinstance (data , str ):
@@ -275,39 +277,34 @@ def symbolic_eq(*args, **kwargs):
275
277
if all_concrete_values ([* args ]):
276
278
return lax_reference .eq (* args , ** kwargs )
277
279
else :
278
- #return '(' + symbolic_infix_operator('==', *args, **kwargs) + ')'
279
280
return symbolic_operator ('lax_reference.eq' , * args , ** kwargs )
280
281
281
282
282
283
def symbolic_ne (* args , ** kwargs ):
283
284
if all_concrete_values ([* args ]):
284
285
return lax_reference .ne (* args , ** kwargs )
285
286
else :
286
- #return '(' + symbolic_infix_operator('!=', *args, **kwargs) + ')'
287
287
return symbolic_operator ('lax_reference.ne' , * args , ** kwargs )
288
288
289
289
290
290
def symbolic_le (* args , ** kwargs ):
291
291
if all_concrete_values ([* args ]):
292
292
return lax_reference .le (* args , ** kwargs )
293
293
else :
294
- #return '(' + symbolic_infix_operator('<=', *args, **kwargs) + ')'
295
294
return symbolic_operator ('lax_reference.le' , * args , ** kwargs )
296
295
297
296
298
297
def symbolic_lt (* args , ** kwargs ):
299
298
if all_concrete_values ([* args ]):
300
299
return lax_reference .lt (* args , ** kwargs )
301
300
else :
302
- #return '(' + symbolic_infix_operator('<', *args, **kwargs) + ')'
303
301
return symbolic_operator ('lax_reference.lt' , * args , ** kwargs )
304
302
305
303
306
304
def symbolic_gt (* args , ** kwargs ):
307
305
if all_concrete_values ([* args ]):
308
306
return lax_reference .gt (* args , ** kwargs )
309
307
else :
310
- #return symbolic_infix_operator('>', *args, **kwargs)
311
308
return symbolic_operator ('lax_reference.gt' , * args , ** kwargs )
312
309
313
310
@@ -393,7 +390,6 @@ def symbolic_and(*args, **kwargs):
393
390
if all_concrete_values ([* args ]):
394
391
return numpy .logical_and (* args , ** kwargs )
395
392
else :
396
- #return symbolic_infix_operator('and', *args, **kwargs)
397
393
return symbolic_operator ('numpy.logical_and' , * args , ** kwargs )
398
394
399
395
@@ -402,33 +398,27 @@ def symbolic_or(*args, **kwargs):
402
398
if all_concrete_values ([* args ]):
403
399
return numpy .logical_or (* args , ** kwargs )
404
400
else :
405
- #return '(' + symbolic_infix_operator('or', *args, **kwargs) + ')'
406
401
return symbolic_operator ('numpy.logical_or' , * args , ** kwargs )
407
402
408
403
409
404
def symbolic_xor (* args , ** kwargs ):
410
405
if all_concrete_values ([* args ]):
411
406
return numpy .logical_xor (* args , ** kwargs )
412
407
else :
413
- #return symbolic_infix_operator('^', *args, **kwargs)
414
408
return symbolic_operator ('numpy.logical_xor' , * args , ** kwargs )
415
409
416
410
417
411
def symbolic_sum (* args , ** kwargs ):
418
412
if all_concrete_values ([* args ]):
419
413
return lax_reference .sum (* args , ** kwargs )
420
414
else :
421
- #return symbolic_infix_operator('+', *args, **kwargs)
422
415
return symbolic_operator ('lax_reference.sum' , * args , ** kwargs )
423
416
424
417
425
418
def symbolic_broadcast_in_dim (* args , ** kwargs ):
426
- assert len (args ) == 1
427
- arg = args [0 ]
428
- if isinstance (args , (list , tuple )):
429
- # reference implementation demands a numpy array
430
- arg = numpy .array (arg )
431
- return lax_reference .broadcast_in_dim (arg , ** kwargs )
419
+ # broadcast_in_dim requires numpy arrays not lists
420
+ args = tuple ([numpy .array (arg ) if isinstance (arg , list ) else arg for arg in args ])
421
+ return lax_reference .broadcast_in_dim (* args , ** kwargs )
432
422
433
423
434
424
def symbolic_reshape (* args , ** kwargs ):
0 commit comments