diff --git a/ext/stringio/stringio.c b/ext/stringio/stringio.c index 43ce0d6..c334647 100644 --- a/ext/stringio/stringio.c +++ b/ext/stringio/stringio.c @@ -1143,38 +1143,57 @@ struct getline_arg { }; static struct getline_arg * -prepare_getline_args(struct getline_arg *arg, int argc, VALUE *argv) +prepare_getline_args(struct StringIO *ptr, struct getline_arg *arg, int argc, VALUE *argv) { - VALUE str, lim, opts; + VALUE rs, lim, opts; long limit = -1; int respect_chomp; - argc = rb_scan_args(argc, argv, "02:", &str, &lim, &opts); - respect_chomp = argc == 0 || !NIL_P(str); + argc = rb_scan_args(argc, argv, "02:", &rs, &lim, &opts); + respect_chomp = argc == 0 || !NIL_P(rs); switch (argc) { case 0: - str = rb_rs; + rs = rb_rs; break; case 1: - if (!NIL_P(str) && !RB_TYPE_P(str, T_STRING)) { - VALUE tmp = rb_check_string_type(str); + if (!NIL_P(rs) && !RB_TYPE_P(rs, T_STRING)) { + VALUE tmp = rb_check_string_type(rs); if (NIL_P(tmp)) { - limit = NUM2LONG(str); - str = rb_rs; + limit = NUM2LONG(rs); + rs = rb_rs; } else { - str = tmp; + rs = tmp; } } break; case 2: - if (!NIL_P(str)) StringValue(str); + if (!NIL_P(rs)) StringValue(rs); if (!NIL_P(lim)) limit = NUM2LONG(lim); break; } - arg->rs = str; + if (!NIL_P(rs)) { + rb_encoding *enc_rs, *enc_io; + enc_rs = rb_enc_get(rs); + enc_io = get_enc(ptr); + if (enc_rs != enc_io && + (rb_enc_str_coderange(rs) != ENC_CODERANGE_7BIT || + (RSTRING_LEN(rs) > 0 && !rb_enc_asciicompat(enc_io)))) { + if (rs == rb_rs) { + rs = rb_enc_str_new(0, 0, enc_io); + rb_str_buf_cat_ascii(rs, "\n"); + rs = rs; + } + else { + rb_raise(rb_eArgError, "encoding mismatch: %s IO with %s RS", + rb_enc_name(enc_io), + rb_enc_name(enc_rs)); + } + } + } + arg->rs = rs; arg->limit = limit; arg->chomp = 0; if (!NIL_P(opts)) { @@ -1302,15 +1321,15 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr) static VALUE strio_gets(int argc, VALUE *argv, VALUE self) { + struct StringIO *ptr = readable(self); struct getline_arg arg; VALUE str; - if (prepare_getline_args(&arg, argc, argv)->limit == 0) { - struct StringIO *ptr = readable(self); + if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) { return rb_enc_str_new(0, 0, get_enc(ptr)); } - str = strio_getline(&arg, readable(self)); + str = strio_getline(&arg, ptr); rb_lastline_set(str); return str; } @@ -1347,16 +1366,16 @@ static VALUE strio_each(int argc, VALUE *argv, VALUE self) { VALUE line; + struct StringIO *ptr = readable(self); struct getline_arg arg; - StringIO(self); RETURN_ENUMERATOR(self, argc, argv); - if (prepare_getline_args(&arg, argc, argv)->limit == 0) { + if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) { rb_raise(rb_eArgError, "invalid limit: 0 for each_line"); } - while (!NIL_P(line = strio_getline(&arg, readable(self)))) { + while (!NIL_P(line = strio_getline(&arg, ptr))) { rb_yield(line); } return self; @@ -1374,15 +1393,15 @@ static VALUE strio_readlines(int argc, VALUE *argv, VALUE self) { VALUE ary, line; + struct StringIO *ptr = readable(self); struct getline_arg arg; - StringIO(self); - ary = rb_ary_new(); - if (prepare_getline_args(&arg, argc, argv)->limit == 0) { + if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) { rb_raise(rb_eArgError, "invalid limit: 0 for readlines"); } - while (!NIL_P(line = strio_getline(&arg, readable(self)))) { + ary = rb_ary_new(); + while (!NIL_P(line = strio_getline(&arg, ptr))) { rb_ary_push(ary, line); } return ary; diff --git a/test/stringio/test_stringio.rb b/test/stringio/test_stringio.rb index cb82841..216b06d 100644 --- a/test/stringio/test_stringio.rb +++ b/test/stringio/test_stringio.rb @@ -88,6 +88,14 @@ def test_gets assert_string("", Encoding::UTF_8, StringIO.new("foo").gets(0)) end + def test_gets_utf_16 + stringio = StringIO.new("line1\nline2\nline3\n".encode("utf-16le")) + assert_equal("line1\n".encode("utf-16le"), stringio.gets) + assert_equal("line2\n".encode("utf-16le"), stringio.gets) + assert_equal("line3\n".encode("utf-16le"), stringio.gets) + assert_nil(stringio.gets) + end + def test_gets_chomp assert_equal(nil, StringIO.new("").gets(chomp: true)) assert_equal("", StringIO.new("\n").gets(chomp: true))