diff --git a/src/ast.c b/src/ast.c index 4799810c41550..87a6fd63d6ccf 100644 --- a/src/ast.c +++ b/src/ast.c @@ -55,6 +55,7 @@ STATIC_INLINE void jl_ast_context_list_delete(jl_ast_context_list_t *node) typedef struct _jl_ast_context_t { fl_context_t fl; fltype_t *jvtype; + fltype_t *jvscalartype; value_t true_sym; value_t false_sym; @@ -215,6 +216,7 @@ static void jl_init_ast_ctx(jl_ast_context_t *ast_ctx) fl_applyn(fl_ctx, 0, symbol_value(symbol(fl_ctx, "__init_globals"))); jl_ast_ctx(fl_ctx)->jvtype = define_opaque_type(fl_ctx->jl_sym, sizeof(void*), NULL, NULL); + jl_ast_ctx(fl_ctx)->jvscalartype = define_opaque_type(fl_ctx->jl_scalar_sym, sizeof(void*), NULL, NULL); assign_global_builtins(fl_ctx, julia_flisp_ast_ext); jl_ast_ctx(fl_ctx)->true_sym = symbol(fl_ctx, "true"); jl_ast_ctx(fl_ctx)->false_sym = symbol(fl_ctx, "false"); @@ -506,8 +508,12 @@ static jl_value_t *scm_to_julia_(fl_context_t *fl_ctx, value_t e, int eo) if (iscprim(e) && cp_class((cprim_t*)ptr(e)) == fl_ctx->wchartype) { return jl_box32(jl_char_type, *(int32_t*)cp_data((cprim_t*)ptr(e))); } - if (iscvalue(e) && cv_class((cvalue_t*)ptr(e)) == jl_ast_ctx(fl_ctx)->jvtype) { - return *(jl_value_t**)cv_data((cvalue_t*)ptr(e)); + if (iscvalue(e)) { + fltype_t *cls = cv_class((cvalue_t*)ptr(e)); + jl_ast_context_t *ctx = jl_ast_ctx(fl_ctx); + if (cls == ctx->jvtype || cls == ctx->jvscalartype) { + return *(jl_value_t**)cv_data((cvalue_t*)ptr(e)); + } } jl_error("malformed tree"); @@ -594,6 +600,14 @@ static value_t julia_to_scm_(fl_context_t *fl_ctx, jl_value_t *v) return julia_to_list2(fl_ctx, (jl_value_t*)newvar_sym, jl_fieldref(v,0)); if (jl_is_long(v) && fits_fixnum(jl_unbox_long(v))) return fixnum(jl_unbox_long(v)); + if (jl_subtype(v,(jl_value_t*)jl_number_type,1)) { + // mark scalars for compress-fuse optimization of broadcast fusion. + // TODO: once #16966 is fixed, include strings (etc.) here; + // (It's okay if we miss some types since this is just an optimization.) + value_t opaque = cvalue(fl_ctx, jl_ast_ctx(fl_ctx)->jvscalartype, sizeof(void*)); + *(jl_value_t**)cv_data((cvalue_t*)ptr(opaque)) = v; + return opaque; + } if (jl_is_ssavalue(v)) jl_error("SSAValue objects should not occur in an AST"); if (jl_is_slot(v)) diff --git a/src/ast.scm b/src/ast.scm index fc01cb05d35ff..54ece0578a8e4 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -9,7 +9,7 @@ ((string? e) (print-to-string e)) ((eq? e #t) "true") ((eq? e #f) "false") - ((eq? (typeof e) 'julia_value) + ((or (eq? (typeof e) 'julia_value) (eq? (typeof e) 'julia_scalar_value)) (let ((s (string e))) (if (string.find s "#jl_sym = symbol(fl_ctx, "julia_value"); + fl_ctx->jl_scalar_sym = symbol(fl_ctx, "julia_scalar_value"); fl_ctx->the_empty_vector = tagptr(alloc_words(fl_ctx, 1), TAG_VECTOR); vector_setsize(fl_ctx->the_empty_vector, 0); diff --git a/src/flisp/flisp.h b/src/flisp/flisp.h index fea46e570ea5e..dfc1fe9cfa5be 100644 --- a/src/flisp/flisp.h +++ b/src/flisp/flisp.h @@ -476,6 +476,7 @@ struct _fl_context_t { value_t apply_func, apply_v, apply_e; value_t jl_sym; + value_t jl_scalar_sym; // persistent buffer (avoid repeated malloc/free) // for julia_extensions.c: normalize size_t jlbuflen; diff --git a/src/flisp/print.c b/src/flisp/print.c index d12cbf0f37338..1f5caafb31992 100644 --- a/src/flisp/print.c +++ b/src/flisp/print.c @@ -657,7 +657,7 @@ static void cvalue_printdata(fl_context_t *fl_ctx, ios_t *f, void *data, #endif init = 1; } - if (jl_static_print != NULL && fl_ctx->jl_sym == type) { + if (jl_static_print != NULL && (fl_ctx->jl_sym == type || fl_ctx->jl_scalar_sym == type)) { fl_ctx->HPOS += ios_printf(f, "#HPOS += jl_static_print(f, *(void**)data); fl_ctx->HPOS += ios_printf(f, ">"); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 0d14e75fd6327..b651c63299b9d 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1679,7 +1679,8 @@ new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) varfarg vararg) (error "multiple splatted args cannot be fused into a single broadcast")))) - ((number? arg) ; inline numeric literals + ((or (number? arg) ; inline numeric literals + (eq? (typeof arg) 'julia_scalar_value)) ; & Julia scalars (cf (cdr old-fargs) (cdr old-args) new-fargs new-args (cons (cons farg arg) renames) diff --git a/test/broadcast.jl b/test/broadcast.jl index dfbbeed1c8950..51ceaf0d3c611 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -288,6 +288,8 @@ let identity = error, x = [1,2,3] x .= 1 # make sure it goes to broadcast!(Base.identity, ...), not identity @test x == [1,1,1] end +# See issue #18176: +@test ((a,b,c) -> a+b+c).(1.0:2, 3, 4) == ((a,b,c) -> a+b+c).(1.0:2, 3.0, 4.0) == [8,9] # PR 16988 @test Base.promote_op(+, Bool) === Int