◐ Shell
clean mode source ↗

PyAnextAwaitable by youknowone · Pull Request #6427 · RustPython/RustPython

/// Awaitable wrapper for anext() builtin with default value. /// When StopAsyncIteration is raised, it converts it to StopIteration(default). #[pyclass(module = false, name = "anext_awaitable")] #[derive(Debug)] pub struct PyAnextAwaitable { wrapped: PyObjectRef, default_value: PyObjectRef, }
impl PyPayload for PyAnextAwaitable { #[inline] fn class(ctx: &Context) -> &'static Py<PyType> { ctx.types.anext_awaitable } }
#[pyclass(with(IterNext, Iterable))] impl PyAnextAwaitable { pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self { Self { wrapped, default_value, } }
#[pymethod(name = "__await__")] fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> { zelf }
/// Get the awaitable iterator from wrapped object. // = anextawaitable_getiter. fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult { use crate::builtins::PyCoroutine; use crate::protocol::PyIter;
let wrapped = &self.wrapped;
// If wrapped is already an async_generator_asend, it's an iterator if wrapped.class().is(vm.ctx.types.async_generator_asend) || wrapped.class().is(vm.ctx.types.async_generator_athrow) { return Ok(wrapped.clone()); }
// _PyCoro_GetAwaitableIter equivalent let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) { // Coroutine - get __await__ later wrapped.clone() } else { // Try to get __await__ method if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) { await_method?.call((), vm)? } else { return Err(vm.new_type_error(format!( "object {} can't be used in 'await' expression", wrapped.class().name() ))); } };
// If awaitable is a coroutine, get its __await__ if awaitable.class().is(vm.ctx.types.coroutine_type) { let coro_await = vm.call_method(&awaitable, "__await__", ())?; // Check that __await__ returned an iterator if !PyIter::check(&coro_await) { return Err(vm.new_type_error("__await__ returned a non-iterable")); } return Ok(coro_await); }
// Check the result is an iterator, not a coroutine if awaitable.downcast_ref::<PyCoroutine>().is_some() { return Err(vm.new_type_error("__await__() returned a coroutine")); }
// Check that the result is an iterator if !PyIter::check(&awaitable) { return Err(vm.new_type_error(format!( "__await__() returned non-iterator of type '{}'", awaitable.class().name() ))); }
Ok(awaitable) }
#[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { let awaitable = self.get_awaitable_iter(vm)?; let result = vm.call_method(&awaitable, "send", (val,)); self.handle_result(result, vm) }
#[pymethod] fn throw( &self, exc_type: PyObjectRef, exc_val: OptionalArg, exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { let awaitable = self.get_awaitable_iter(vm)?; let result = vm.call_method( &awaitable, "throw", ( exc_type, exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), ), ); self.handle_result(result, vm) }
#[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { if let Ok(awaitable) = self.get_awaitable_iter(vm) { let _ = vm.call_method(&awaitable, "close", ()); } Ok(()) }
/// Convert StopAsyncIteration to StopIteration(default_value) fn handle_result(&self, result: PyResult, vm: &VirtualMachine) -> PyResult { match result { Ok(value) => Ok(value), Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) => { Err(vm.new_stop_iteration(Some(self.default_value.clone()))) } Err(exc) => Err(exc), } } }
impl SelfIter for PyAnextAwaitable {} impl IterNext for PyAnextAwaitable { fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> { PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm) } }