Skip to content

Commit 49483cb

Browse files
committed
more asyncgen impl
1 parent c11a72a commit 49483cb

5 files changed

Lines changed: 166 additions & 13 deletions

File tree

Lib/test/test_asyncgen.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -467,16 +467,12 @@ async def test_throw():
467467
result = self.loop.run_until_complete(test_throw())
468468
self.assertEqual(result, "completed")
469469

470-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
471-
@unittest.expectedFailure
472470
def test_async_generator_anext(self):
473471
async def agen():
474472
yield 1
475473
yield 2
476474
self.check_async_iterator_anext(agen)
477475

478-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
479-
@unittest.expectedFailure
480476
def test_python_async_iterator_anext(self):
481477
class MyAsyncIter:
482478
"""Asynchronously yield 1, then 2."""
@@ -492,8 +488,6 @@ async def __anext__(self):
492488
return self.yielded
493489
self.check_async_iterator_anext(MyAsyncIter)
494490

495-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
496-
@unittest.expectedFailure
497491
def test_python_async_iterator_types_coroutine_anext(self):
498492
import types
499493
class MyAsyncIterWithTypesCoro:
@@ -549,8 +543,6 @@ async def gen():
549543
applied_twice = aiter(applied_once)
550544
self.assertIs(applied_once, applied_twice)
551545

552-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
553-
@unittest.expectedFailure
554546
def test_anext_bad_args(self):
555547
async def gen():
556548
yield 1
@@ -571,7 +563,7 @@ async def call_with_kwarg():
571563
with self.assertRaises(TypeError):
572564
self.loop.run_until_complete(call_with_kwarg())
573565

574-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
566+
# TODO: RUSTPYTHON, error message mismatch
575567
@unittest.expectedFailure
576568
def test_anext_bad_await(self):
577569
async def bad_awaitable():
@@ -642,7 +634,7 @@ async def do_test():
642634
result = self.loop.run_until_complete(do_test())
643635
self.assertEqual(result, "completed")
644636

645-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
637+
# TODO: RUSTPYTHON, anext coroutine iteration issue
646638
@unittest.expectedFailure
647639
def test_anext_iter(self):
648640
@types.coroutine

crates/vm/src/builtins/asyncgenerator.rs

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,151 @@ impl IterNext for PyAsyncGenAThrow {
424424
}
425425
}
426426

427+
/// Awaitable wrapper for anext() builtin with default value.
428+
/// When StopAsyncIteration is raised, it converts it to StopIteration(default).
429+
#[pyclass(module = false, name = "anext_awaitable")]
430+
#[derive(Debug)]
431+
pub struct PyAnextAwaitable {
432+
wrapped: PyObjectRef,
433+
default_value: PyObjectRef,
434+
}
435+
436+
impl PyPayload for PyAnextAwaitable {
437+
#[inline]
438+
fn class(ctx: &Context) -> &'static Py<PyType> {
439+
ctx.types.anext_awaitable
440+
}
441+
}
442+
443+
#[pyclass(with(IterNext, Iterable))]
444+
impl PyAnextAwaitable {
445+
pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self {
446+
Self {
447+
wrapped,
448+
default_value,
449+
}
450+
}
451+
452+
#[pymethod(name = "__await__")]
453+
fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
454+
zelf
455+
}
456+
457+
/// Get the awaitable iterator from wrapped object.
458+
/// This is equivalent to CPython's anextawaitable_getiter.
459+
fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
460+
use crate::builtins::PyCoroutine;
461+
use crate::protocol::PyIter;
462+
463+
let wrapped = &self.wrapped;
464+
465+
// If wrapped is already an async_generator_asend, it's an iterator
466+
if wrapped.class().is(vm.ctx.types.async_generator_asend)
467+
|| wrapped.class().is(vm.ctx.types.async_generator_athrow)
468+
{
469+
return Ok(wrapped.clone());
470+
}
471+
472+
// _PyCoro_GetAwaitableIter equivalent
473+
let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) {
474+
// Coroutine - get __await__ later
475+
wrapped.clone()
476+
} else {
477+
// Try to get __await__ method
478+
if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
479+
await_method?.call((), vm)?
480+
} else {
481+
return Err(vm.new_type_error(format!(
482+
"object {} can't be used in 'await' expression",
483+
wrapped.class().name()
484+
)));
485+
}
486+
};
487+
488+
// If awaitable is a coroutine, get its __await__
489+
if awaitable.class().is(vm.ctx.types.coroutine_type) {
490+
let coro_await = vm.call_method(&awaitable, "__await__", ())?;
491+
// Check that __await__ returned an iterator
492+
if !PyIter::check(&coro_await) {
493+
return Err(vm.new_type_error("__await__ returned a non-iterable"));
494+
}
495+
return Ok(coro_await);
496+
}
497+
498+
// Check the result is an iterator, not a coroutine
499+
if awaitable.downcast_ref::<PyCoroutine>().is_some() {
500+
return Err(vm.new_type_error("__await__() returned a coroutine"));
501+
}
502+
503+
// Check that the result is an iterator
504+
if !PyIter::check(&awaitable) {
505+
return Err(vm.new_type_error(format!(
506+
"__await__() returned non-iterator of type '{}'",
507+
awaitable.class().name()
508+
)));
509+
}
510+
511+
Ok(awaitable)
512+
}
513+
514+
#[pymethod]
515+
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
516+
let awaitable = self.get_awaitable_iter(vm)?;
517+
let result = vm.call_method(&awaitable, "send", (val,));
518+
self.handle_result(result, vm)
519+
}
520+
521+
#[pymethod]
522+
fn throw(
523+
&self,
524+
exc_type: PyObjectRef,
525+
exc_val: OptionalArg,
526+
exc_tb: OptionalArg,
527+
vm: &VirtualMachine,
528+
) -> PyResult {
529+
let awaitable = self.get_awaitable_iter(vm)?;
530+
let result = vm.call_method(
531+
&awaitable,
532+
"throw",
533+
(
534+
exc_type,
535+
exc_val.unwrap_or_none(vm),
536+
exc_tb.unwrap_or_none(vm),
537+
),
538+
);
539+
self.handle_result(result, vm)
540+
}
541+
542+
#[pymethod]
543+
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
544+
if let Ok(awaitable) = self.get_awaitable_iter(vm) {
545+
let _ = vm.call_method(&awaitable, "close", ());
546+
}
547+
Ok(())
548+
}
549+
550+
/// Convert StopAsyncIteration to StopIteration(default_value)
551+
fn handle_result(&self, result: PyResult, vm: &VirtualMachine) -> PyResult {
552+
match result {
553+
Ok(value) => Ok(value),
554+
Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) => {
555+
Err(vm.new_stop_iteration(Some(self.default_value.clone())))
556+
}
557+
Err(exc) => Err(exc),
558+
}
559+
}
560+
}
561+
562+
impl SelfIter for PyAnextAwaitable {}
563+
impl IterNext for PyAnextAwaitable {
564+
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
565+
PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
566+
}
567+
}
568+
427569
pub fn init(ctx: &Context) {
428570
PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
429571
PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
430572
PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
573+
PyAnextAwaitable::extend_class(ctx, ctx.types.anext_awaitable);
431574
}

crates/vm/src/builtins/coroutine.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ impl PyCoroutineWrapper {
156156
) -> PyResult<PyIterReturn> {
157157
self.coro.throw(exc_type, exc_val, exc_tb, vm)
158158
}
159+
160+
#[pymethod]
161+
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
162+
self.coro.close(vm)
163+
}
159164
}
160165

161166
impl SelfIter for PyCoroutineWrapper {}

crates/vm/src/stdlib/builtins.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,23 @@ mod builtins {
540540
default_value: OptionalArg<PyObjectRef>,
541541
vm: &VirtualMachine,
542542
) -> PyResult {
543+
use crate::builtins::asyncgenerator::PyAnextAwaitable;
544+
545+
// Check if object is an async iterator (has __anext__ method)
546+
if !aiter.class().has_attr(identifier!(vm, __anext__)) {
547+
return Err(vm.new_type_error(format!(
548+
"'{}' object is not an async iterator",
549+
aiter.class().name()
550+
)));
551+
}
552+
543553
let awaitable = vm.call_method(&aiter, "__anext__", ())?;
544554

545-
if default_value.is_missing() {
546-
Ok(awaitable)
555+
if let OptionalArg::Present(default) = default_value {
556+
Ok(PyAnextAwaitable::new(awaitable, default)
557+
.into_ref(&vm.ctx)
558+
.into())
547559
} else {
548-
// TODO: Implement CPython like PyAnextAwaitable to properly handle the default value.
549560
Ok(awaitable)
550561
}
551562
}

crates/vm/src/types/zoo.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub struct TypeZoo {
2121
pub async_generator_asend: &'static Py<PyType>,
2222
pub async_generator_athrow: &'static Py<PyType>,
2323
pub async_generator_wrapped_value: &'static Py<PyType>,
24+
pub anext_awaitable: &'static Py<PyType>,
2425
pub bytes_type: &'static Py<PyType>,
2526
pub bytes_iterator_type: &'static Py<PyType>,
2627
pub bytearray_type: &'static Py<PyType>,
@@ -139,6 +140,7 @@ impl TypeZoo {
139140
async_generator_athrow: asyncgenerator::PyAsyncGenAThrow::init_builtin_type(),
140141
async_generator_wrapped_value:
141142
asyncgenerator::PyAsyncGenWrappedValue::init_builtin_type(),
143+
anext_awaitable: asyncgenerator::PyAnextAwaitable::init_builtin_type(),
142144
bound_method_type: function::PyBoundMethod::init_builtin_type(),
143145
builtin_function_or_method_type: builtin_func::PyNativeFunction::init_builtin_type(),
144146
builtin_method_type: builtin_func::PyNativeMethod::init_builtin_type(),

0 commit comments

Comments
 (0)