Skip to content

Commit

Permalink
Use PyTuple for more idiomatic PyO3 code. (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi authored Sep 18, 2024
1 parent d1992bb commit 0a7b29b
Showing 1 changed file with 62 additions and 43 deletions.
105 changes: 62 additions & 43 deletions pytrustfall/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{collections::BTreeMap, sync::Arc};

use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyIterator, wrap_pyfunction};
use pyo3::{
exceptions::PyStopIteration, prelude::*, types::PyIterator, types::PyTuple, wrap_pyfunction,
};
use trustfall_core::{
frontend::{error::FrontendError, parse},
interpreter::{
Expand Down Expand Up @@ -359,24 +361,27 @@ impl Iterator for PythonResolvePropertyIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, property_value) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();

// TODO: if this panics, we got an unrepresentable FieldValue,
// which should be a proper error
let value: FieldValue = output
.call_method_bound(py, "__getitem__", (1i64,), None)
.unwrap()
.extract(py)
.unwrap();

Some((context, value))
// `output` must be a (context, property_value) tuple here, or else we panic.
let tuple = output.downcast_bound(py).expect(
"resolve_property() did not yield a `(context, property_value)` tuple",
);

let tuple_size_error: &'static str =
"resolve_property() yielded a tuple that did not have exactly 2 elements";

let property_value: FieldValue =
tuple.get_borrowed_item(1).expect(tuple_size_error).extract().expect(
"resolve_property() tuple element at index 1 is not a property value",
);

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_property() tuple element at index 0 is not a context (Opaque) value");

Some((context, property_value))
}
Err(e) => {
if e.is_instance_of::<PyStopIteration>(py) {
Expand Down Expand Up @@ -407,21 +412,27 @@ impl Iterator for PythonResolveNeighborsIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, neighbor_iterator) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();
let neighbors_iterable =
output.call_method_bound(py, "__getitem__", (1i64,), None).unwrap();

// Allow returning iterables (e.g. []), not just iterators.
// Iterators return self when __iter__() is called.
// `output` must be a (context, neighbor_iterator) tuple here, or else we panic.
let tuple: &Bound<'_, PyTuple> = output.downcast_bound(py).expect(
"resolve_neighbors() did not yield a `(context, neighbor_iterator)` tuple",
);

let tuple_size_error: &'static str =
"resolve_neighbors() yielded a tuple that did not have exactly 2 elements";

let neighbors_iterable = tuple.get_borrowed_item(1).expect(tuple_size_error);

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_neighbors() tuple element at index 0 is not a context (Opaque) value");

// Support returning iterables (e.g. []), not just iterators.
// Iterators return self when `__iter__()` is called.
let neighbors_iter = make_iterator(
neighbors_iterable.bind(py),
&neighbors_iterable,
"resolve_neighbors() yielded tuple's second element",
);

Expand Down Expand Up @@ -458,19 +469,27 @@ impl Iterator for PythonResolveCoercionIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, can_coerce) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();
let can_coerce: bool = output
.call_method_bound(py, "__getitem__", (1i64,), None)
.unwrap()
.extract::<bool>(py)
.unwrap();
// `output` must be a (context, can_coerce) tuple here, or else we panic.
let tuple = output
.downcast_bound(py)
.expect("resolve_coercion() did not yield a `(context, can_coerce)` tuple");

let tuple_size_error: &'static str =
"resolve_coercion() yielded a tuple that did not have exactly 2 elements";

let can_coerce: bool = tuple
.get_borrowed_item(1)
.expect(tuple_size_error)
.extract()
.expect("resolve_coercion() tuple element at index 1 is not a bool");

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_coercion() tuple element at index 0 is not a context (Opaque) value");

Some((context, can_coerce))
}
Err(e) => {
Expand Down

0 comments on commit 0a7b29b

Please sign in to comment.