Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inline native Julia scalars encountered in f.(args...) AST #18202

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ STATIC_INLINE void jl_ast_context_list_delete(jl_ast_context_list_t *node)
typedef struct _jl_ast_context_t {
fl_context_t fl;
fltype_t *jvtype;
fltype_t *jvscalartype;

value_t true_sym;
value_t false_sym;
Expand Down Expand Up @@ -215,6 +216,7 @@ static void jl_init_ast_ctx(jl_ast_context_t *ast_ctx)
fl_applyn(fl_ctx, 0, symbol_value(symbol(fl_ctx, "__init_globals")));

jl_ast_ctx(fl_ctx)->jvtype = define_opaque_type(fl_ctx->jl_sym, sizeof(void*), NULL, NULL);
jl_ast_ctx(fl_ctx)->jvscalartype = define_opaque_type(fl_ctx->jl_scalar_sym, sizeof(void*), NULL, NULL);
assign_global_builtins(fl_ctx, julia_flisp_ast_ext);
jl_ast_ctx(fl_ctx)->true_sym = symbol(fl_ctx, "true");
jl_ast_ctx(fl_ctx)->false_sym = symbol(fl_ctx, "false");
Expand Down Expand Up @@ -506,8 +508,12 @@ static jl_value_t *scm_to_julia_(fl_context_t *fl_ctx, value_t e, int eo)
if (iscprim(e) && cp_class((cprim_t*)ptr(e)) == fl_ctx->wchartype) {
return jl_box32(jl_char_type, *(int32_t*)cp_data((cprim_t*)ptr(e)));
}
if (iscvalue(e) && cv_class((cvalue_t*)ptr(e)) == jl_ast_ctx(fl_ctx)->jvtype) {
return *(jl_value_t**)cv_data((cvalue_t*)ptr(e));
if (iscvalue(e)) {
fltype_t *cls = cv_class((cvalue_t*)ptr(e));
jl_ast_context_t *ctx = jl_ast_ctx(fl_ctx);
if (cls == ctx->jvtype || cls == ctx->jvscalartype) {
return *(jl_value_t**)cv_data((cvalue_t*)ptr(e));
}
}
jl_error("malformed tree");

Expand Down Expand Up @@ -594,6 +600,14 @@ static value_t julia_to_scm_(fl_context_t *fl_ctx, jl_value_t *v)
return julia_to_list2(fl_ctx, (jl_value_t*)newvar_sym, jl_fieldref(v,0));
if (jl_is_long(v) && fits_fixnum(jl_unbox_long(v)))
return fixnum(jl_unbox_long(v));
if (jl_subtype(v,(jl_value_t*)jl_number_type,1)) {
// mark scalars for compress-fuse optimization of broadcast fusion.
// TODO: once #16966 is fixed, include strings (etc.) here;
// (It's okay if we miss some types since this is just an optimization.)
value_t opaque = cvalue(fl_ctx, jl_ast_ctx(fl_ctx)->jvscalartype, sizeof(void*));
*(jl_value_t**)cv_data((cvalue_t*)ptr(opaque)) = v;
return opaque;
}
if (jl_is_ssavalue(v))
jl_error("SSAValue objects should not occur in an AST");
if (jl_is_slot(v))
Expand Down
2 changes: 1 addition & 1 deletion src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
((string? e) (print-to-string e))
((eq? e #t) "true")
((eq? e #f) "false")
((eq? (typeof e) 'julia_value)
((or (eq? (typeof e) 'julia_value) (eq? (typeof e) 'julia_scalar_value))
(let ((s (string e)))
(if (string.find s "#<julia: ")
;; successfully printed as a julia value
Expand Down
1 change: 1 addition & 0 deletions src/flisp/flisp.c
Original file line number Diff line number Diff line change
Expand Up @@ -2398,6 +2398,7 @@ static void lisp_init(fl_context_t *fl_ctx, size_t initial_heapsize)
#endif

fl_ctx->jl_sym = symbol(fl_ctx, "julia_value");
fl_ctx->jl_scalar_sym = symbol(fl_ctx, "julia_scalar_value");

fl_ctx->the_empty_vector = tagptr(alloc_words(fl_ctx, 1), TAG_VECTOR);
vector_setsize(fl_ctx->the_empty_vector, 0);
Expand Down
1 change: 1 addition & 0 deletions src/flisp/flisp.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ struct _fl_context_t {
value_t apply_func, apply_v, apply_e;

value_t jl_sym;
value_t jl_scalar_sym;
// persistent buffer (avoid repeated malloc/free)
// for julia_extensions.c: normalize
size_t jlbuflen;
Expand Down
2 changes: 1 addition & 1 deletion src/flisp/print.c
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ static void cvalue_printdata(fl_context_t *fl_ctx, ios_t *f, void *data,
#endif
init = 1;
}
if (jl_static_print != NULL && fl_ctx->jl_sym == type) {
if (jl_static_print != NULL && (fl_ctx->jl_sym == type || fl_ctx->jl_scalar_sym == type)) {
fl_ctx->HPOS += ios_printf(f, "#<julia: ");
fl_ctx->HPOS += jl_static_print(f, *(void**)data);
fl_ctx->HPOS += ios_printf(f, ">");
Expand Down
3 changes: 2 additions & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1679,7 +1679,8 @@
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
((or (number? arg) ; inline numeric literals
(eq? (typeof arg) 'julia_scalar_value)) ; & Julia scalars
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
Expand Down
2 changes: 2 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ let identity = error, x = [1,2,3]
x .= 1 # make sure it goes to broadcast!(Base.identity, ...), not identity
@test x == [1,1,1]
end
# See issue #18176:
@test ((a,b,c) -> a+b+c).(1.0:2, 3, 4) == ((a,b,c) -> a+b+c).(1.0:2, 3.0, 4.0) == [8,9]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this is also tested in #18200, instead of this, a test of the resulting expression after parsing would be more useful?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could check the result of expand.


# PR 16988
@test Base.promote_op(+, Bool) === Int
Expand Down