diff --git a/README.md b/README.md
index fe6a2c676b..4e68842c90 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob
 
 ## 🚀 News
 
+* [2026-01] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 released: upgraded verl to v0.7.0, added OpenAI API support for the Tinker backend, and fixed several bugs.
 * [2026-01] Introducing [R3L](https://github.com/shiweijiezero/R3L): a systematic reflect-then-retry RL mechanism with efficient language-guided exploration and stable off-policy learning ([paper](https://arxiv.org/abs/2601.03715)).
 * [2025-12] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 released: added [Tinker](https://thinkingmachines.ai/tinker/) backend for users **without GPUs**, add more benchmarks, enhance online RL and more.
 * [2025-12] Trinity-RFT powers the medical and health business of "Taobao Shangou", enabling the AI agent to understand vague symptoms, proactively ask follow-up questions, and provide precise recommendations ([News](https://tech.china.com.cn/sx/20251201/411376.shtml)).
diff --git a/README_zh.md b/README_zh.md
index 4638f55945..305f89b68f 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -41,6 +41,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能:
 
 ## 🚀 新闻
 
+* [2026-01] [[发布说明]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 发布:升级 verl 至 v0.7.0,Tinker 后端支持 OpenAI API,修复若干 Bug。
 * [2026-01] 推出 [R3L](https://github.com/shiweijiezero/R3L):基于反思-重试的强化学习机制,由自然语言反馈引导高效探索,并达成稳定的 off-policy 学习([论文](https://arxiv.org/abs/2601.03715))。
 * [2025-12] [[发布说明]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 发布:新增[Tinker](https://thinkingmachines.ai/tinker/) 后端以支持在 **无 GPU** 的设备上训练,增加更多基准测试,增强在线 RL 等功能。
 * [2025-12] Trinity-RFT 已支持 [tinker](https://thinkingmachines.ai/tinker/) 训练后端,可在**无 GPU 的设备**上进行模型训练。
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 19dda755bd..c0e16d0645 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -97,7 +97,7 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
-    warmup_style: "warmup"
+    lr_scheduler_type: "constant"
   # The following parameters are optional
   # If not specified, they will automatically be set based on the `algorithm_type`
   sample_strategy: "default"
@@ -111,7 +111,8 @@ algorithm:
 - `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`. Some algorithms such as GRPO and OPMD require `repeat_times` > 1.
 - `optimizer`: Optimizer configuration for actor.
   - `lr`: Learning rate for actor.
-  - `warmup_style`: Warmup style for actor's learning rate.
+  - `warmup_style`: Deprecated; use `lr_scheduler_type` instead. This field will be removed in a future version.
+  - `lr_scheduler_type`: Learning rate scheduler type for the actor model. Default is `constant`. Supported types: `constant`, `cosine`.
 - `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The advantage function used for computing advantages.
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
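The documentation change above replaces `warmup_style` with `lr_scheduler_type` and marks the old field as deprecated rather than removed. As a rough illustration of that documented behavior (the new field wins, the old one is still accepted, and the default is `constant`), here is a minimal Python sketch of a compatibility shim; `resolve_lr_scheduler_type` is a hypothetical helper written for this note, not Trinity-RFT's actual config-handling code.

```python
import warnings
from typing import Optional


def resolve_lr_scheduler_type(
    lr_scheduler_type: Optional[str] = None,
    warmup_style: Optional[str] = None,
) -> str:
    """Hypothetical shim mapping the deprecated `warmup_style` onto `lr_scheduler_type`."""
    if lr_scheduler_type is not None:
        return lr_scheduler_type
    if warmup_style is not None:
        # Mirror the deprecation note in the docs: keep working, but warn.
        warnings.warn(
            "`warmup_style` is deprecated; use `lr_scheduler_type` instead.",
            DeprecationWarning,
        )
        return warmup_style
    return "constant"  # documented default
```

Per the updated docs, the supported scheduler types are `constant` and `cosine`.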
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index d130517920..8f00929da5 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -97,7 +97,7 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
-    warmup_style: constant
+    lr_scheduler_type: constant
   # 以下参数为可选
   # 若未指定,将根据 `algorithm_type` 自动设置
   sample_strategy: "default"
@@ -111,7 +111,8 @@ algorithm:
 - `repeat_times`: 每个任务重复的次数。默认为 `1`。在 `dpo` 中自动设为 `2`。某些算法如 GRPO 和 OPMD 要求 `repeat_times` > 1。
 - `optimizer`: Actor 优化器的参数。
   - `lr`: 优化器的学习率。
-  - `warmup_style`: 学习率的预热策略。
+  - `warmup_style`: 已弃用,请改用 `lr_scheduler_type`。该字段将在未来版本中移除。
+  - `lr_scheduler_type`: Actor 模型的学习率调度器类型。默认值为 `constant`。支持类型:`constant`、`cosine`。
 - `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。支持类型:`default`、`staleness_control`、`mix`。
 - `advantage_fn`: 用于计算优势值的函数。
 - `kl_penalty_fn`: 用于在奖励中计算 KL 惩罚的函数。
diff --git a/pyproject.toml b/pyproject.toml
index 624519e62c..b4342cc4d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "trinity-rft"
-version = "0.4.0"
+version = "0.4.1"
 authors = [
     {name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
 ]
@@ -88,7 +88,7 @@ tinker = [
 ]
 
 doc = [
-    "sphinx",
+    "sphinx<9.0.0",
    "sphinx-autobuild",
    "sphinx-book-theme",
    "myst-parser",
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 78d465c452..3a40988a07 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1270,7 +1270,6 @@ def setUp(self):
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.explorer.rollout_model.enable_lora = True
-        self.config.explorer.rollout_model.enable_runtime_lora_updating = True
         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -1345,3 +1344,68 @@ async def test_tinker_api(self):
         self.assertEqual(response.sequences[0].stop_reason, "length")
         self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
         self.assertIsNone(response.topk_prompt_logprobs)
+
+        # test adding and removing LoRA adapters
+        from vllm.lora.request import LoRARequest
+
+        # create a dummy lora adapter with all zero weights
+        lora_path_1 = os.path.join(self.config.checkpoint_job_dir, "adapter_1")
+        lora_path_2 = os.path.join(self.config.checkpoint_job_dir, "adapter_2")
+        _create_adapter(self.config.model.model_path, lora_path_1, "adapter_1")
+        _create_adapter(self.config.model.model_path, lora_path_2, "adapter_2")
+        lora_1 = LoRARequest(
+            lora_name="test_adapter_1",
+            lora_int_id=1,
+            lora_path=os.path.join(lora_path_1, "adapter_1"),
+        )
+        lora_2 = LoRARequest(
+            lora_name="test_adapter_2",
+            lora_int_id=2,
+            lora_path=os.path.join(lora_path_2, "adapter_2"),
+        )
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_1,
+        )
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [1])
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_2,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        await engine.remove_lora_adapter.remote(lora_id=1)
+        await engine.remove_lora_adapter.remote(lora_id=2)
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [])
+
+
+def _create_adapter(model_path: str, lora_path: str, name: str):
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map="cpu",
+    )
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=8,
+        target_modules=["gate_proj", "up_proj", "down_proj"],
+        lora_dropout=0.1,
+    )
+    lora_model = get_peft_model(model, lora_config, adapter_name=name)
+    lora_model.save_pretrained(lora_path)
diff --git a/trinity/__init__.py b/trinity/__init__.py
index 26314a1a39..ef49c83930 100644
--- a/trinity/__init__.py
+++ b/trinity/__init__.py
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Trinity-RFT (Reinforcement Fine-Tuning)"""
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 0fef9c5d4e..6c01cddc71 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -403,6 +403,35 @@ async def logprobs(  # type: ignore [override]
             dtype=torch.float32,
         )
 
+    async def add_lora_adapter(self, lora_request: Any) -> int:
+        """Add a LoRA adapter to the vLLM engine.
+
+        Args:
+            lora_request (LoRARequest): The LoRA request.
+
+        Returns:
+            lora_id (int): The LoRA adapter ID.
+        """
+        lora_id = await self.async_llm.add_lora(lora_request)
+        return lora_id
+
+    async def remove_lora_adapter(self, lora_id: int) -> None:
+        """Remove a LoRA adapter from the vLLM engine.
+
+        Args:
+            lora_id (int): The LoRA adapter ID.
+        """
+        await self.async_llm.remove_lora(lora_id)
+
+    async def list_lora_adapters(self) -> Sequence[int]:
+        """List all LoRA adapter IDs in the vLLM engine.
+
+        Returns:
+            lora_ids (List[int]): The list of LoRA adapter IDs.
+        """
+        lora_ids = await self.async_llm.list_loras()
+        return list(lora_ids)
+
     async def sample(
         self,
         prompt: Any,
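The new `add_lora_adapter`, `remove_lora_adapter`, and `list_lora_adapters` methods, together with the test added above, imply the runtime-LoRA flow sketched below. This is an illustrative sketch only: `engine`, `prompt`, and `sampling_params` stand for the same objects used in `test_tinker_api`, `demo_runtime_lora` is a hypothetical helper, and the adapter path is a placeholder for a directory saved by something like `_create_adapter`.

```python
from vllm.lora.request import LoRARequest


async def demo_runtime_lora(engine, prompt, sampling_params):
    """Hypothetical walkthrough of the runtime-LoRA methods added above.

    `engine` is assumed to be the Ray actor wrapping the vLLM rollout model
    (as returned by `create_inference_models`); `prompt` and `sampling_params`
    are built the same way as in `test_tinker_api`.
    """
    lora = LoRARequest(
        lora_name="demo_adapter",      # placeholder adapter name
        lora_int_id=1,
        lora_path="/path/to/adapter",  # placeholder: a saved PEFT adapter directory
    )

    # As in the test, passing a LoRARequest to sample() is enough to load the adapter.
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=1,
        sampling_params=sampling_params,
        lora_request=lora,
    )

    # The adapter is now registered under its integer ID ...
    assert await engine.list_lora_adapters.remote() == [1]

    # ... and can be unloaded once it is no longer needed.
    await engine.remove_lora_adapter.remote(lora_id=1)
    assert await engine.list_lora_adapters.remote() == []

    return response
```

The test above exercises exactly this flow with two adapters created by `_create_adapter`.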