@@ -424,8 +424,151 @@ impl IterNext for PyAsyncGenAThrow {
424424 }
425425}
426426
427+ /// Awaitable wrapper for anext() builtin with default value.
428+ /// When StopAsyncIteration is raised, it converts it to StopIteration(default).
429+ #[ pyclass( module = false , name = "anext_awaitable" ) ]
430+ #[ derive( Debug ) ]
431+ pub struct PyAnextAwaitable {
432+ wrapped : PyObjectRef ,
433+ default_value : PyObjectRef ,
434+ }
435+
436+ impl PyPayload for PyAnextAwaitable {
437+ #[ inline]
438+ fn class ( ctx : & Context ) -> & ' static Py < PyType > {
439+ ctx. types . anext_awaitable
440+ }
441+ }
442+
443+ #[ pyclass( with( IterNext , Iterable ) ) ]
444+ impl PyAnextAwaitable {
445+ pub fn new ( wrapped : PyObjectRef , default_value : PyObjectRef ) -> Self {
446+ Self {
447+ wrapped,
448+ default_value,
449+ }
450+ }
451+
452+ #[ pymethod( name = "__await__" ) ]
453+ fn r#await ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
454+ zelf
455+ }
456+
457+ /// Get the awaitable iterator from wrapped object.
458+ /// This is equivalent to CPython's anextawaitable_getiter.
459+ fn get_awaitable_iter ( & self , vm : & VirtualMachine ) -> PyResult {
460+ use crate :: builtins:: PyCoroutine ;
461+ use crate :: protocol:: PyIter ;
462+
463+ let wrapped = & self . wrapped ;
464+
465+ // If wrapped is already an async_generator_asend, it's an iterator
466+ if wrapped. class ( ) . is ( vm. ctx . types . async_generator_asend )
467+ || wrapped. class ( ) . is ( vm. ctx . types . async_generator_athrow )
468+ {
469+ return Ok ( wrapped. clone ( ) ) ;
470+ }
471+
472+ // _PyCoro_GetAwaitableIter equivalent
473+ let awaitable = if wrapped. class ( ) . is ( vm. ctx . types . coroutine_type ) {
474+ // Coroutine - get __await__ later
475+ wrapped. clone ( )
476+ } else {
477+ // Try to get __await__ method
478+ if let Some ( await_method) = vm. get_method ( wrapped. clone ( ) , identifier ! ( vm, __await__) ) {
479+ await_method?. call ( ( ) , vm) ?
480+ } else {
481+ return Err ( vm. new_type_error ( format ! (
482+ "object {} can't be used in 'await' expression" ,
483+ wrapped. class( ) . name( )
484+ ) ) ) ;
485+ }
486+ } ;
487+
488+ // If awaitable is a coroutine, get its __await__
489+ if awaitable. class ( ) . is ( vm. ctx . types . coroutine_type ) {
490+ let coro_await = vm. call_method ( & awaitable, "__await__" , ( ) ) ?;
491+ // Check that __await__ returned an iterator
492+ if !PyIter :: check ( & coro_await) {
493+ return Err ( vm. new_type_error ( "__await__ returned a non-iterable" ) ) ;
494+ }
495+ return Ok ( coro_await) ;
496+ }
497+
498+ // Check the result is an iterator, not a coroutine
499+ if awaitable. downcast_ref :: < PyCoroutine > ( ) . is_some ( ) {
500+ return Err ( vm. new_type_error ( "__await__() returned a coroutine" ) ) ;
501+ }
502+
503+ // Check that the result is an iterator
504+ if !PyIter :: check ( & awaitable) {
505+ return Err ( vm. new_type_error ( format ! (
506+ "__await__() returned non-iterator of type '{}'" ,
507+ awaitable. class( ) . name( )
508+ ) ) ) ;
509+ }
510+
511+ Ok ( awaitable)
512+ }
513+
514+ #[ pymethod]
515+ fn send ( & self , val : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
516+ let awaitable = self . get_awaitable_iter ( vm) ?;
517+ let result = vm. call_method ( & awaitable, "send" , ( val, ) ) ;
518+ self . handle_result ( result, vm)
519+ }
520+
521+ #[ pymethod]
522+ fn throw (
523+ & self ,
524+ exc_type : PyObjectRef ,
525+ exc_val : OptionalArg ,
526+ exc_tb : OptionalArg ,
527+ vm : & VirtualMachine ,
528+ ) -> PyResult {
529+ let awaitable = self . get_awaitable_iter ( vm) ?;
530+ let result = vm. call_method (
531+ & awaitable,
532+ "throw" ,
533+ (
534+ exc_type,
535+ exc_val. unwrap_or_none ( vm) ,
536+ exc_tb. unwrap_or_none ( vm) ,
537+ ) ,
538+ ) ;
539+ self . handle_result ( result, vm)
540+ }
541+
542+ #[ pymethod]
543+ fn close ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
544+ if let Ok ( awaitable) = self . get_awaitable_iter ( vm) {
545+ let _ = vm. call_method ( & awaitable, "close" , ( ) ) ;
546+ }
547+ Ok ( ( ) )
548+ }
549+
550+ /// Convert StopAsyncIteration to StopIteration(default_value)
551+ fn handle_result ( & self , result : PyResult , vm : & VirtualMachine ) -> PyResult {
552+ match result {
553+ Ok ( value) => Ok ( value) ,
554+ Err ( exc) if exc. fast_isinstance ( vm. ctx . exceptions . stop_async_iteration ) => {
555+ Err ( vm. new_stop_iteration ( Some ( self . default_value . clone ( ) ) ) )
556+ }
557+ Err ( exc) => Err ( exc) ,
558+ }
559+ }
560+ }
561+
562+ impl SelfIter for PyAnextAwaitable { }
563+ impl IterNext for PyAnextAwaitable {
564+ fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
565+ PyIterReturn :: from_pyresult ( zelf. send ( vm. ctx . none ( ) , vm) , vm)
566+ }
567+ }
568+
427569pub fn init ( ctx : & Context ) {
428570 PyAsyncGen :: extend_class ( ctx, ctx. types . async_generator ) ;
429571 PyAsyncGenASend :: extend_class ( ctx, ctx. types . async_generator_asend ) ;
430572 PyAsyncGenAThrow :: extend_class ( ctx, ctx. types . async_generator_athrow ) ;
573+ PyAnextAwaitable :: extend_class ( ctx, ctx. types . anext_awaitable ) ;
431574}
0 commit comments