diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 6f9ff54e48..0edf92c475 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -2,9 +2,15 @@ # # See LICENSE for license information. +function error_exit() { + echo "Error: $1" + exit 1 +} + : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" # NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py