Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only compare scalars in full_like #2857

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/awkward/operations/ak_full_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.typetracer import ensure_known_scalar
from awkward._nplikes.typetracer import is_unknown_scalar
from awkward._regularize import is_integer_like
from awkward.operations.ak_zeros_like import _ZEROS

__all__ = ("full_like",)
Expand Down Expand Up @@ -125,12 +126,20 @@ def action(layout, backend, **kwargs):
if layout.is_numpy:
original = nplike.asarray(layout.data)

if fill_value is _ZEROS or ensure_known_scalar(fill_value == 0, False):
if fill_value is _ZEROS or (
is_integer_like(fill_value)
and not is_unknown_scalar(fill_value)
and fill_value == 0
):
return ak.contents.NumpyArray(
nplike.zeros_like(original, dtype=dtype),
parameters=layout.parameters,
)
elif ensure_known_scalar(fill_value == 1, False):
elif (
is_integer_like(fill_value)
and not is_unknown_scalar(fill_value)
and fill_value == 1
):
return ak.contents.NumpyArray(
nplike.ones_like(original, dtype=dtype),
parameters=layout.parameters,
Expand Down
36 changes: 36 additions & 0 deletions tests/test_2857_full_like_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest # noqa: F401

import awkward as ak


def test():
arr = ak.Array([{"x": 1}, {"x": 2}])
# Fill with
result = ak.full_like(arr, np.datetime64(20, "s"), dtype="<M8[s]")
assert result.layout.is_equal_to(
ak.contents.RecordArray(
[
ak.contents.NumpyArray(
np.array([20, 20], dtype=np.dtype("datetime64[s]"))
)
],
["x"],
)
)


def test_typetracer():
arr = ak.Array([{"x": 1}, {"x": 2}], backend="typetracer")
# Fill with
result = ak.full_like(arr, np.datetime64(20, "s"), dtype="<M8[s]")
assert result.layout.form == (
ak.forms.RecordForm(
[ak.forms.NumpyForm("datetime64[s]")],
["x"],
)
)