Skip to content

Commit

Permalink
Refactor scattering za grid (#926)
Browse files Browse the repository at this point in the history
Make scattering za grid accessible in Python interface.
Remove exposure of `ZenithAngleGrid` and refactored the variant
alternatives.
ZenithAngleGrids support the numpy interface now.

@simonpf We also removed the inheritance from Vector as the interface
was not used anywhere anyway and it simplifies the nanobind interface.
Each grid contains a member `angles` instead now.
  • Loading branch information
riclarsson authored Feb 6, 2025
2 parents 234b5be + 6790b0a commit 5c5f3f1
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 91 deletions.
4 changes: 2 additions & 2 deletions src/core/scattering/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void LobattoQuadrature::calculate_nodes_and_weights() {
}

IrregularZenithAngleGrid::IrregularZenithAngleGrid(const Vector& zenith_angles)
: ZenithAngleGrid(zenith_angles),
: angles(zenith_angles),
weights_(zenith_angles.size()),
cos_theta_(zenith_angles),
type_(QuadratureType::Trapezoidal) {
Expand All @@ -132,7 +132,7 @@ IrregularZenithAngleGrid::IrregularZenithAngleGrid(const Vector& zenith_angles)
cos_theta_.begin(),
[](Numeric lat) { return -1.0 * cos(Conversion::deg2rad(lat)); });
weights_ = 0.0;
Index n = static_cast<Index>(Vector::size());
Index n = static_cast<Index>(angles.size());
for (Index i = 0; i < n - 1; ++i) {
auto dx = 0.5 * (cos_theta_[i + 1] - cos_theta_[i]);
weights_[i] += dx;
Expand Down
57 changes: 13 additions & 44 deletions src/core/scattering/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,45 +245,11 @@ class QuadratureProvider {
std::map<Index, Quadrature> quadratures_;
};

namespace detail {
/** Base class for zenith-angle grids.
*
* This class defines the basic interface for zenith-angle grids.
* It is used to store the cosine values of the grid points and,
* in addition to that, the integration weights of the corresponding
* quadrature.
*/
class ZenithAngleGrid : public Vector {
public:
ZenithAngleGrid() : Vector() {}
ZenithAngleGrid(const Vector& zenith_angles) : Vector(zenith_angles) {}
ZenithAngleGrid(Index i) : Vector(i) {}

virtual ~ZenithAngleGrid(){};

ZenithAngleGrid(const ZenithAngleGrid&) = default;
ZenithAngleGrid& operator=(const ZenithAngleGrid&) = default;
ZenithAngleGrid(ZenithAngleGrid&&) = default;
ZenithAngleGrid& operator=(ZenithAngleGrid&&) = default;

/// The cosine of the grid points.
virtual const Vector& get_angle_cosines() const = 0;
/// The grid points in radians.
virtual const Vector& get_angles() const { return *this; }
/// The integration weights.
virtual const Vector& get_weights() const = 0;

/// The type of quadrature.
virtual QuadratureType get_type() = 0;
};
}


using ZenithAngleGridPtr = std::shared_ptr<detail::ZenithAngleGrid>;
using ConstZenithAngleGridPtr = std::shared_ptr<const detail::ZenithAngleGrid>;

class IrregularZenithAngleGrid : public detail::ZenithAngleGrid {
class IrregularZenithAngleGrid {
public:
Vector angles;

IrregularZenithAngleGrid() = default;
/** Create new zenith-angle grid.
* @param zenith_angles Vector containing the zenith-angle grid points in radians.
Expand All @@ -307,8 +273,10 @@ class IrregularZenithAngleGrid : public detail::ZenithAngleGrid {
};

template <typename Quadrature>
class QuadratureZenithAngleGrid : public detail::ZenithAngleGrid {
class QuadratureZenithAngleGrid {
public:
Vector angles;

/** Create new quadrature zenith-angle grid with given number of points.
*
* Creates a zenith-angle grid using the nodes and weights of the given quadrature
Expand All @@ -318,12 +286,12 @@ class QuadratureZenithAngleGrid : public detail::ZenithAngleGrid {
* weights of the quadrature.
* @param degree The number of points of the quadrature.
*/
QuadratureZenithAngleGrid() : detail::ZenithAngleGrid() {}
QuadratureZenithAngleGrid() : angles() {}
QuadratureZenithAngleGrid(const QuadratureZenithAngleGrid&) = default;
QuadratureZenithAngleGrid(Index n_points)
: detail::ZenithAngleGrid(n_points), quadrature_(n_points) {
: angles(n_points), quadrature_(n_points) {
auto nodes = quadrature_.get_nodes();
std::transform(nodes.begin(), nodes.end(), begin(), [](Numeric x) {
std::transform(nodes.begin(), nodes.end(), angles.begin(), [](Numeric x) {
return Conversion::rad2deg(acos(-1.0 * x));
});
}
Expand Down Expand Up @@ -360,21 +328,22 @@ static QuadratureProvider<FejerQuadrature> quadratures =

using ZenithAngleGrid = std::variant<
IrregularZenithAngleGrid,
DoubleGaussGrid,
GaussLegendreGrid,
LobattoGrid,
FejerGrid
>;

inline Index grid_size(const ZenithAngleGrid &grid) {
return std::visit([](const auto &grd) { return grd.size(); }, grid);
return std::visit([](const auto &grd) { return grd.angles.size(); }, grid);
}

inline StridedVectorView grid_vector(ZenithAngleGrid &grid) {
return std::visit([](auto &grd) { return static_cast<StridedVectorView>(grd); }, grid);
return std::visit([](auto &grd) { return static_cast<StridedVectorView>(grd.angles); }, grid);
}

inline StridedConstVectorView grid_vector(const ZenithAngleGrid &grid) {
return std::visit([](const auto &grd) { return static_cast<StridedConstVectorView>(grd); }, grid);
return std::visit([](const auto &grd) { return static_cast<StridedConstVectorView>(grd.angles); }, grid);
}


Expand Down
6 changes: 3 additions & 3 deletions src/core/scattering/phase_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,11 @@ inline RegridWeights calc_regrid_weights(
}
if ((za_scat_grid) && (new_grids.za_scat_grid)) {
res.za_scat_grid_weights = ArrayOfGridPos(std::visit(
[](const auto &grd) { return grd.size(); }, *new_grids.za_scat_grid));
[](const auto &grd) { return grd.angles.size(); }, *new_grids.za_scat_grid));
gridpos(res.za_scat_grid_weights,
std::visit([](const auto &grd) { return static_cast<Vector>(grd); },
std::visit([](const auto &grd) { return static_cast<Vector>(grd.angles); },
*za_scat_grid),
std::visit([](const auto &grd) { return static_cast<Vector>(grd); },
std::visit([](const auto &grd) { return static_cast<Vector>(grd.angles); },
*new_grids.za_scat_grid),
1e99);
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/scattering/sht.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ ArrayOfIndex SHT::get_m_indices() {
FejerGrid SHT::get_zenith_angle_grid(Index n_za, bool radians) {
auto result = FejerGrid(n_za);
if (radians) {
result *= Conversion::deg2rad(1.0);
result.angles *= Conversion::deg2rad(1.0);
}
return result;
};
Expand Down
4 changes: 2 additions & 2 deletions src/core/scattering/xml_io_scattering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void xml_write_to_stream(std::ostream &os_xml,
ArtsXMLTag open_tag, close_tag;
open_tag.set_name("IrregularZenithAngleGrid");
open_tag.write_to_stream(os_xml);
xml_write_to_stream(os_xml, static_cast<const Vector&>(grid), pbofs, name);
xml_write_to_stream(os_xml, grid.angles, pbofs, name);
close_tag.set_name("/IrregularZenithAngleGrid");
close_tag.write_to_stream(os_xml);
os_xml << '\n';
Expand All @@ -56,7 +56,7 @@ void xml_read_from_stream(std::istream &is_xml,

tag.read_from_stream(is_xml);
tag.check_name("IrregularZenithAngleGrid");
xml_read_from_stream(is_xml, static_cast<Vector&>(grid), pbifs);
xml_read_from_stream(is_xml, grid.angles, pbifs);
}


Expand Down
90 changes: 54 additions & 36 deletions src/python_interface/py_scattering_species.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <python_interface.h>
#include <stdexcept>

#include "hpy_arts.h"
#include "hpy_numpy.h"
#include "py_macros.h"

NB_MAKE_OPAQUE(scattering::ZenithAngleGrid);

namespace Python {

Expand All @@ -25,13 +26,20 @@ auto bind_phase_matrix_data_tro_gridded(py::module_& m,
scattering::Representation::Gridded>;

py::class_<PMD, matpack::data_t<Scalar, 4>> s(m, class_name.c_str());
s.def(py::init<std::shared_ptr<const Vector>,
std::shared_ptr<const Vector>,
std::shared_ptr<const scattering::ZenithAngleGrid>>(),
py::arg("t_grid"),
py::arg("f_grid"),
py::arg("za_scat_grid"))

s.def(
"__init__",
[](PMD* self,
std::shared_ptr<const Vector> t_grid,
std::shared_ptr<const Vector> f_grid,
scattering::ZenithAngleGrid za_grid) {
new (self) PMD{
t_grid,
f_grid,
std::make_shared<const scattering::ZenithAngleGrid>(std::move(za_grid))};
},
py::arg("t_grid"),
py::arg("f_grid"),
py::arg("za_scat_grid"))
// Bind methods, such as `extract_backscatter_matrix` and `extract_forwardscatter_matrix`
.def("extract_backscatter_matrix",
&PMD::extract_backscatter_matrix,
Expand All @@ -42,9 +50,13 @@ auto bind_phase_matrix_data_tro_gridded(py::module_& m,

.def("get_t_grid", &PMD::get_t_grid, "Get temperature grid")
.def("get_f_grid", &PMD::get_f_grid, "Get frequency grid")
.def("get_za_scat_grid",
&PMD::get_za_scat_grid,
"Get scattering zenith angle grid")
.def(
"get_za_scat_grid",
[](const PMD& pmd) {
if (pmd.get_za_scat_grid()) return *(pmd.get_za_scat_grid());
throw std::runtime_error("PMD Zenith angle grid not initialized");
},
"Get scattering zenith angle grid")

.def(
"to_spectral",
Expand Down Expand Up @@ -373,41 +385,47 @@ void py_scattering_species(py::module_& m) try {
[](const HenyeyGreensteinScatterer& hg,
const AtmPoint& atm_point,
const Vector& f_grid,
std::shared_ptr<scattering::ZenithAngleGrid> za_grid) {
scattering::ZenithAngleGrid za_grid) {
return BulkScatteringPropertiesTROGridded{
hg.get_bulk_scattering_properties_tro_gridded(
atm_point, f_grid, za_grid)};
atm_point, f_grid, std::make_shared<scattering::ZenithAngleGrid>(std::move(za_grid)))};
},
"atm_point"_a,
"f_grid"_a,
"za_grid"_a,
"Get bulk scattering properties")
.doc() = "Henyey-Greenstein scatterer";

py::class_<scattering::IrregularZenithAngleGrid>(m,
"IrregularZenithAngleGrid")
.def(py::init<Vector>())
py::class_<scattering::IrregularZenithAngleGrid> irr_grid(m,
"IrregularZenithAngleGrid");
irr_grid.def(py::init<Vector>())
.def_rw("value", &scattering::IrregularZenithAngleGrid::angles, "Zenith angle grid")
.doc() = "Irregular zenith angle grid";
py::class_<scattering::GaussLegendreGrid>(m, "GaussLegendreGrid")
.def(py::init<Index>())
common_ndarray(irr_grid);

py::class_<scattering::GaussLegendreGrid> gauss_grid(m, "GaussLegendreGrid");
gauss_grid.def(py::init<Index>())
.def_rw("value", &scattering::GaussLegendreGrid::angles, "Zenith angle grid for Legendre calculations")
.doc() = "Gaussian Legendre grid";
py::class_<scattering::DoubleGaussGrid>(m, "DoubleGaussGrid")
.def(py::init<Index>())
common_ndarray(gauss_grid);

py::class_<scattering::DoubleGaussGrid> double_gauss_grid(m, "DoubleGaussGrid");
double_gauss_grid.def(py::init<Index>())
.def_rw("value", &scattering::DoubleGaussGrid::angles, "Zenith angle grid for Double Gauss calculations")
.doc() = "Double Gaussian grid";
py::class_<scattering::LobattoGrid>(m, "LobattoGrid")
.def(py::init<Index>())
common_ndarray(double_gauss_grid);

py::class_<scattering::LobattoGrid> lobatto_grid(m, "LobattoGrid");
lobatto_grid.def(py::init<Index>())
.def_rw("value", &scattering::LobattoGrid::angles, "Zenith angle grid for Lobatto calculations")
.doc() = "Lobatto grid";
py::class_<scattering::FejerGrid>(m, "FejerGrid")
.def(py::init<Index>())
.doc() = "Fejer grid";
common_ndarray(lobatto_grid);

py::class_<scattering::ZenithAngleGrid>(m, "ZenithAngleGrid")
.def(py::init<scattering::IrregularZenithAngleGrid>())
.def(py::init<scattering::GaussLegendreGrid>())
.def(py::init<scattering::DoubleGaussGrid>())
.def(py::init<scattering::LobattoGrid>())
.def(py::init<scattering::FejerGrid>())
.doc() = "Zenith angle grid";
py::class_<scattering::FejerGrid> fejer_grid(m, "FejerGrid");
fejer_grid.def(py::init<Index>())
.def_rw("value", &scattering::FejerGrid::angles, "Zenith angle grid for Fejer calculations")
.doc() = "Fejer grid";
common_ndarray(fejer_grid);

py::class_<ArrayOfScatteringSpecies> aoss(m, "ArrayOfScatteringSpecies");
aoss.def(py::init<>())
Expand All @@ -431,10 +449,10 @@ void py_scattering_species(py::module_& m) try {
[](const ArrayOfScatteringSpecies& aoss,
const AtmPoint& atm_point,
const Vector& f_grid,
std::shared_ptr<scattering::ZenithAngleGrid> za_grid) {
scattering::ZenithAngleGrid za_grid) {
return BulkScatteringPropertiesTROGridded{
aoss.get_bulk_scattering_properties_tro_gridded(
atm_point, f_grid, za_grid)};
atm_point, f_grid, std::make_shared<scattering::ZenithAngleGrid>(std::move(za_grid)))};
},
"atm_point"_a,
"f_grid"_a,
Expand All @@ -447,13 +465,13 @@ void py_scattering_species(py::module_& m) try {
const Vector& f_grid,
const Vector& za_inc_grid,
const Vector& delta_aa_grid,
std::shared_ptr<scattering::ZenithAngleGrid> za_scat_grid) {
scattering::ZenithAngleGrid za_scat_grid) {
return BulkScatteringPropertiesAROGridded{
aoss.get_bulk_scattering_properties_aro_gridded(atm_point,
f_grid,
za_inc_grid,
delta_aa_grid,
za_scat_grid)};
std::make_shared<scattering::ZenithAngleGrid>(std::move(za_scat_grid)))};
},
"atm_point"_a,
"f_grid"_a,
Expand Down
2 changes: 1 addition & 1 deletion src/tests/scattering/test_phase_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ bool test_backscatter_matrix_regrid_tro() {
auto sht = sht::provider.get_instance(1, 32);
auto t_grid = std::make_shared<Vector>(Vector({210.0, 250.0, 270.0}));
auto f_grid = std::make_shared<Vector>(Vector({1e9, 10e9, 100e9}));
auto za_scat_grid = std::make_shared<ZenithAngleGrid>(IrregularZenithAngleGrid(sht->get_zenith_angle_grid()));
auto za_scat_grid = std::make_shared<ZenithAngleGrid>(IrregularZenithAngleGrid(sht->get_zenith_angle_grid().angles));
auto phase_matrix = make_phase_matrix(t_grid, f_grid, za_scat_grid);
auto backscatter_matrix = phase_matrix.extract_backscatter_matrix();

Expand Down
4 changes: 2 additions & 2 deletions src/tests/scattering/test_sht.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ bool test_grids() try {
auto sht = sht::provider.get_instance(64, 64);

auto lat_grid = sht->get_zenith_angle_grid();
Numeric max_angle = max<Vector>(lat_grid);
Numeric max_angle = max<Vector>(lat_grid.angles);
if (max_angle < 2.0 * scattering::sht::pi_v<Numeric>) {
return false;
}
lat_grid = sht->get_zenith_angle_grid(true);
max_angle = max<Vector>(lat_grid);
max_angle = max<Vector>(lat_grid.angles);
if (max_angle > 2.0 * scattering::sht::pi_v<Numeric>) {
return false;
}
Expand Down

0 comments on commit 5c5f3f1

Please sign in to comment.