From d65ed46271e836571477875f7d3b80ca8cdabf45 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 12 Jan 2026 17:51:47 +0800
Subject: [PATCH 1/2] Add openai client support for tinker backend

---
 scripts/context_length_test/README.md |  4 +-
 tests/common/vllm_test.py             | 57 +++++++++++++--
 trinity/common/config.py              |  5 ++
 trinity/common/models/__init__.py     | 20 +++---
 trinity/common/models/model.py        | 84 ++++++++++++++++++++---
 trinity/common/models/tinker_model.py | 99 ++++++++++++++++++++++++---
 trinity/explorer/workflow_runner.py   |  7 +-
 7 files changed, 237 insertions(+), 39 deletions(-)

diff --git a/scripts/context_length_test/README.md b/scripts/context_length_test/README.md
index 10c7a535e2..de4c847fc6 100644
--- a/scripts/context_length_test/README.md
+++ b/scripts/context_length_test/README.md
@@ -124,7 +124,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 
 ### A100 80GB
 
-#### Vallina Settings (Baseline)
+#### Vanilla Settings (Baseline)
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
 | ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
@@ -177,7 +177,7 @@ Below are empirical results from running this script on various Qwen3 models acr
 
 ### H20 96GB (Higher VRAM, Slower Bandwidth)
 
-#### Vallina Settings
+#### Vanilla Settings
 
 
 | #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 4aeee3adb3..bdd1e4672a 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -669,10 +669,17 @@ async def test_logprobs_api(self):
 
 
 class TestAsyncAPIServer(RayUnittestBaseAsync):
-    def setUp(self):
+    engine_type: str = "vllm"
+    model_path: str = get_model_path()
+
+    async def asyncSetUp(self):
         self.config = get_template_config()
+        self._update_config()
+        await self._setup_engines()
+
+    def _update_config(self):
         self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
+        self.config.model.model_path = self.model_path
         self.config.explorer.rollout_model.engine_type = "vllm"
         self.config.explorer.rollout_model.engine_num = 1
         self.config.explorer.rollout_model.tensor_parallel_size = 1
@@ -680,10 +687,14 @@ def setUp(self):
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.check_and_update()
+
+    async def _setup_engines(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], engine_type=self.engine_type, enable_history=True
+        )
         self.model_wrapper_no_history = ModelWrapper(
-            self.engines[0], engine_type="vllm", enable_history=False
+            self.engines[0], engine_type=self.engine_type, enable_history=False
         )
 
     async def test_api_async(self):
@@ -695,7 +706,7 @@
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What is your name?"},
         ]
-        model_id = (await openai_client.models.list()).data[0].id
+        model_id = openai_client.model_path
         response = await openai_client.chat.completions.create(
             model=model_id, messages=messages, n=1
         )
@@ -713,7 +724,8 @@
         self.assertTrue(response.choices[0].logprobs is not None)
         self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
         # here we check the 3rd token logprob, because the first two tokens (``,`\n` usually have zero logprob)
-        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
+        if "Instruct" not in self.model_path:
+            self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -765,6 +777,39 @@
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
+    engine_type: str = "tinker"
+    model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
+    # llama model in Tinker does not support chat template
+
+    def _update_config(self):
+        self.config.model.tinker.enable = True
+        self.config.algorithm.algorithm_type = "grpo"
+        super()._update_config()
+        from pprint import pprint
+
+        pprint(self.config)
+
+    async def _setup_engines(self):
+        import ray
+
+        from trinity.common.config import Config
+        from trinity.manager.synchronizer import Synchronizer
+
+        @ray.remote
+        class FakeTrainer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+        fake_trainer = FakeTrainer.remote(self.config)
+        await fake_trainer.__ray_ready__.remote()
+        await super()._setup_engines()
+
+    async def test_api_async(self):
+        await super().test_api_async()
+
+
 class TestTokenizer(unittest.TestCase):
     def test_action_mask(self):
         messages = [
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 83b751fece..c80d84b61e 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -1218,6 +1218,11 @@ def _check_tinker(self) -> None:
             self.explorer.rollout_model.engine_type = "tinker"
             logger.warning("Rollout model engine type is set to `tinker`.")
 
+        for aux_model_config in self.explorer.auxiliary_models:
+            if aux_model_config.engine_type != "tinker":
+                aux_model_config.engine_type = "tinker"
+                logger.warning("Auxiliary model engine type is set to `tinker`.")
+
         if self.trainer.trainer_type != "tinker":
             self.trainer.trainer_type = "tinker"
             logger.warning("Trainer type is set to `tinker`.")
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 42674ea147..9dd77e6cb4 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -71,16 +71,18 @@ def create_inference_models(
             for i in range(engine_num)
         ]
         auxiliary_engines = [
-            ray.remote(engine_cls)
-            .options(
-                name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
-                namespace=namespace,
-            )
-            .remote(
-                config=config.explorer.auxiliary_models[i],
-            )
+            [
+                ray.remote(engine_cls)
+                .options(
+                    name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
+                    namespace=namespace,
+                )
+                .remote(
+                    config=config.explorer.auxiliary_models[i],
+                )
+                for j in range(model_config.engine_num)
+            ]
             for i, model_config in enumerate(config.explorer.auxiliary_models)
-            for j in range(model_config.engine_num)
         ]
         return rollout_engines, auxiliary_engines
     else:
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index ac28276876..77839f7cd8 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
         """Get the API server URL if available."""
         return None
 
+    def get_api_key(self) -> str:
+        """Get the API key."""
+        return "EMPTY"
+
     def get_model_config(self) -> InferenceModelConfig:
         """Get the model configuration."""
         return self.config
 
+    def get_model_path(self) -> Optional[str]:
+        """Get the model path"""
+        return self.config.model_path
+
 
 def _history_recorder(func):
     """Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
             engine_type.startswith("vllm") or engine_type == "tinker"
         ), "Only vLLM and tinker model is supported for now."
         self.model = model
+        self.engine_type = engine_type
         self.config: InferenceModelConfig = None  # init during prepare
         self._model_name: str = None
-        self._model_path: str = None
         self.api_address: str = None
+        self._api_key: str = None
         self.openai_client: openai.OpenAI = None
         self.openai_async_client: openai.AsyncOpenAI = None
         self.logger = get_logger(__name__)
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
         """Prepare the model wrapper."""
         self.config = await self.model.get_model_config.remote()
         self._model_name = self.config.name
-        self._model_path = self.config.model_path
+        self._api_key = await self.model.get_api_key.remote()
         self._generate_kwargs = {
             "temperature": self.config.temperature,
             "top_p": self.config.top_p,
@@ -152,6 +161,8 @@
         if self.api_address is None:
             self.logger.info("API server is not enabled for inference model.")
             return
+        if self.engine_type == "tinker":
+            return
         max_retries = 30
         interval = 2  # seconds
         for i in range(max_retries):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
             messages, tools=tools, temperature=temperature
         )
 
+    @property
+    def api_key(self) -> str:
+        """Get the API key."""
+        return self._api_key
+
     @property
     def model_version(self) -> int:
         """Get the version of the model."""
@@ -298,7 +314,12 @@ async def model_version_async(self) -> int:
     @property
     def model_path(self) -> str:
         """Get the model path."""
-        return self._model_path
+        return ray.get(self.model.get_model_path.remote())
+
+    @property
+    async def model_path_async(self) -> str:
+        """Get the model path."""
+        return await self.model.get_model_path.remote()
 
     @property
     def model_name(self) -> Optional[str]:
@@ -332,6 +353,7 @@ def get_openai_client(self) -> openai.OpenAI:
             openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_client is not None:
+            setattr(self.openai_client, "model_path", self.model_path)
             return self.openai_client
         if not self.api_address:
             raise ValueError(
         )
         self.openai_client = openai.OpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use the original API in tinker instead.
+            ori_create = self.openai_async_client.chat.completions.create
+
+            async def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = ray.get(
+                    self.model.chat.remote(
+                        messages=messages,
+                        with_chat_completion=True,
+                        return_token_ids=self.enable_history,
+                        **kwargs,
+                    )
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_async_client.chat.completions.create = chat_completions
+        elif self.enable_history:
             # add a decorator to the openai client to record history
             ori_create = self.openai_client.chat.completions.create
@@ -359,7 +402,7 @@ def record_chat_completions(*args, **kwargs):
                 return response
 
             self.openai_client.chat.completions.create = record_chat_completions
-        setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id)
+        setattr(self.openai_client, "model_path", self.model_path)
         return self.openai_client
 
     def get_openai_async_client(self) -> openai.AsyncOpenAI:
@@ -369,6 +412,7 @@
             openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
         """
         if self.openai_async_client is not None:
+            setattr(self.openai_async_client, "model_path", self.model_path)
             return self.openai_async_client
         if not self.api_address:
             raise ValueError(
@@ -377,9 +421,29 @@
         # first make sure that we have the sync openai client
         self.openai_async_client = openai.AsyncOpenAI(
             base_url=f"{self.api_address}/v1",
-            api_key="EMPTY",
+            api_key=self._api_key,
         )
-        if self.enable_history:
+
+        if self.engine_type == "tinker":
+            # ! TODO: because tinker's OpenAI API interface is in beta,
+            # we need to use the original API in tinker instead.
+            ori_create = self.openai_async_client.chat.completions.create
+
+            async def chat_completions(*args, **kwargs):
+                messages = kwargs.pop("messages")
+                chat_response = await self.model.chat.remote(
+                    messages=messages,
+                    with_chat_completion=True,
+                    return_token_ids=self.enable_history,
+                    **kwargs,
+                )
+                response = chat_response.pop()
+                if self.enable_history:
+                    self.history.extend(chat_response)
+                return response
+
+            self.openai_async_client.chat.completions.create = chat_completions
+        elif self.enable_history:
             # add a decorator to the openai client to record history
             ori_create = self.openai_async_client.chat.completions.create
@@ -400,8 +464,8 @@ async def record_chat_completions(*args, **kwargs):
             self.openai_async_client.chat.completions.create = record_chat_completions
 
         # get model_path from the sync openai client to avoid async call here
-        openai_client = self.get_openai_client()
-        setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id)
+        # openai_client = self.get_openai_client()
+        setattr(self.openai_async_client, "model_path", self.model_path)
         return self.openai_async_client
 
     async def get_current_load(self) -> int:
diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py
index 92451526aa..b94986d587 100644
--- a/trinity/common/models/tinker_model.py
+++ b/trinity/common/models/tinker_model.py
@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence
+import time
+from os import getenv
+from typing import List, Optional
 
 import ray
 import tinker
@@ -22,6 +24,7 @@ def __init__(
         self.model_version = -1
         self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace)
         self.model = None
+        self.model_path = config.model_path
         self.tokenizer = None
         self.chat_template = None
         if self.config.chat_template:
@@ -31,10 +34,10 @@
 
     async def _initialize_tokenizer(self) -> None:
         """Initialize the tokenizer."""
-        trainer_client = await self.service_client.create_lora_training_client_async(
+        self.trainer_client = await self.service_client.create_lora_training_client_async(
             base_model=self.config.model_path
         )
-        self.tokenizer = trainer_client.get_tokenizer()
+        self.tokenizer = self.trainer_client.get_tokenizer()
 
     async def _generate_internal(self, prompt: dict, **kwargs) -> types.SampleResponse:
         assert self.model is not None
@@ -54,7 +57,7 @@
             topk_prompt_logprobs=kwargs.get("topk_prompt_logprobs", self.config.logprobs),
         )
 
-    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
+    async def generate(self, prompt: str, **kwargs) -> List[Experience]:
         """Generate a responses from a prompt in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
@@ -88,26 +91,83 @@
                 for _ in range(kwargs.get("n", 1))
             ]
 
+        with_chat_completion = kwargs.get("with_chat_completion", False)
+        if with_chat_completion:
+            create_time = int(time.time())
         output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
+        return_logprobs = kwargs.get("logprobs", self.config.logprobs is not None)
         experiences = [
             Experience(
                 tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32),
-                logprobs=torch.tensor(sequence.logprobs, dtype=torch.float32),
+                logprobs=(
+                    torch.tensor(sequence.logprobs, dtype=torch.float32)
+                    if return_logprobs
+                    else torch.tensor([], dtype=torch.float32)
+                ),
                 prompt_length=len(token_ids),
                 prompt_text=self.tokenizer.decode(token_ids),
                 response_text=self.tokenizer.decode(sequence.tokens),
             )
             for sequence in output.sequences
         ]
+        if with_chat_completion:
+            from openai.types.chat.chat_completion import (
+                ChatCompletion,
+                ChatCompletionMessage,
+                ChatCompletionTokenLogprob,
+                Choice,
+                ChoiceLogprobs,
+            )
+
+            return_token_ids = kwargs.get("return_token_ids", False)
+            chat_completion = ChatCompletion(
+                id="",
+                choices=[
+                    Choice(
+                        finish_reason=sequence.stop_reason,
+                        index=i,
+                        logprobs=ChoiceLogprobs(
+                            content=[
+                                ChatCompletionTokenLogprob(
+                                    token=self.tokenizer.decode(token_id),
+                                    logprob=logprob,
+                                    top_logprobs=[],
+                                )
+                                for token_id, logprob in zip(sequence.tokens, sequence.logprobs)
+                            ]
+                        ),
+                        message=ChatCompletionMessage(
+                            content=self.tokenizer.decode(sequence.tokens), role="assistant"
+                        ),
+                        token_ids=(sequence.tokens if return_token_ids else None),
+                    )
+                    for i, sequence in enumerate(output.sequences)
+                ],
+                created=create_time,
+                model=self.model_path,
+                object="chat.completion",
+                prompt_token_ids=token_ids,
+            )
+            experiences.append(chat_completion)
         return experiences
 
-    async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
+    async def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
        """Generate experiences from a list of history chat messages in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
         if self.chat_template is None:
             self.chat_template = self.tokenizer.get_chat_template()
+
+        # TODO: this is a hack to support OpenAI chat messages; only text content is supported
+        for msg in messages:
+            if isinstance(msg["content"], list):
+                text_parts = [item["text"] for item in msg["content"] if item["type"] == "text"]
+                content_str = "".join(text_parts)
+            else:
+                content_str = msg["content"]
+            msg["content"] = content_str
+
         if messages[-1]["role"] == "assistant":
             prompt = self.tokenizer.apply_chat_template(
                 messages,
@@ -127,7 +187,7 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
 
     async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
-        logprobs = await self.model.compute_logprobs_async(types.ModelInput(token_ids))
+        logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids))
         return torch.tensor(logprobs[1:], dtype=torch.float32)
 
     async def convert_messages_to_experience(
@@ -180,6 +240,7 @@ async def prepare(self) -> None:
         self.model = await self.service_client.create_sampling_client_async(
             base_model=self.config.model_path,
         )
+        await self._initialize_tokenizer()
 
     async def sync_model(self, model_version: int) -> int:
         self.model_version = model_version
@@ -187,6 +248,7 @@
         self.model = await self.service_client.create_sampling_client_async(
             model_path=remote_sampler_path,
         )
+        self.model_path = remote_sampler_path
         return model_version
 
     def get_model_version(self) -> int:
@@ -194,6 +256,23 @@
         return self.model_version
 
     def get_api_server_url(self) -> Optional[str]:
-        """Get the API server URL if available."""
-        # TODO: tinker will support openai api later
-        return None
+        """
+        Get the Tinker OpenAI API interface URL.
+
+        Documentation: https://tinker-docs.thinkingmachines.ai/compatible-apis/openai
+
+        Note: This URL is currently not in active use because Tinker's OpenAI-compatible
+        API implementation is still incomplete.
+        Instead, we're using our custom `self.chat()` method to replicate the
+        functionality of `openai.OpenAI.chat.completions.create()`.
+
+        Once Tinker's API is fully implemented and stable, we plan to switch to using this
+        official endpoint directly.
+        """
+        return "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/"
+
+    def get_api_key(self):
+        return getenv("TINKER_API_KEY")
+
+    def get_model_path(self) -> Optional[str]:
+        """Get the model path"""
+        return self.model_path
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 9edba9014f..aeadba1ffb 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -61,7 +61,7 @@ def __init__(
         self.model = model
         self.model_wrapper = ModelWrapper(
             model,
-            config.explorer.rollout_model.engine_type,
+            engine_type=config.explorer.rollout_model.engine_type,
             enable_lora=config.explorer.rollout_model.enable_lora,
             enable_history=config.explorer.rollout_model.enable_history,
         )
@@ -69,8 +69,11 @@
         self.auxiliary_model_wrappers = [
             ModelWrapper(
                 model,
+                engine_type=aux_model_config.engine_type,
+            )
+            for model, aux_model_config in zip(
+                self.auxiliary_models, config.explorer.auxiliary_models
             )
-            for model in self.auxiliary_models
         ]
         self.workflow_instance: Workflow = None
         self.runner_id = runner_id

From 8b2c0f6304f89874576b635ec2a493c728dad84b Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 12 Jan 2026 18:26:18 +0800
Subject: [PATCH 2/2] apply reviews

---
 tests/common/vllm_test.py             | 12 ++++--------
 tests/trainer/trainer_test.py         |  2 +-
 trinity/common/models/model.py        | 23 ++++++++++++++---------
 trinity/common/models/tinker_model.py |  8 ++++----
 4 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index bdd1e4672a..78d465c452 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -2,6 +2,7 @@
 import os
 import unittest
 
+import ray
 import torch
 from openai import BadRequestError
 from parameterized import parameterized_class
@@ -13,12 +14,14 @@
     get_model_path,
     get_template_config,
 )
+from trinity.common.config import Config
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
     tokenize_and_mask_messages_default,
     tokenize_and_mask_messages_hf,
 )
+from trinity.manager.synchronizer import Synchronizer
 
 DEBUG = False
 
@@ -777,6 +780,7 @@ async def test_api_async(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
 class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
     engine_type: str = "tinker"
     model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
@@ -786,16 +790,8 @@ def _update_config(self):
         self.config.model.tinker.enable = True
         self.config.algorithm.algorithm_type = "grpo"
         super()._update_config()
-        from pprint import pprint
-
-        pprint(self.config)
 
     async def _setup_engines(self):
-        import ray
-
-        from trinity.common.config import Config
-        from trinity.manager.synchronizer import Synchronizer
-
         @ray.remote
         class FakeTrainer:
             def __init__(self, config: Config):
                 self.config = config
                 self.synchronizer = Synchronizer.get_actor(config)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 9626b8c4b6..3cd4c8f856 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1402,7 +1402,7 @@ def tearDown(self):
 
 
 class TestTinkerTrainer(BaseTrainerCase):
-    @unittest.skip("Require tinker API key")
@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set") def test_trainer(self): """Test GSM8K on tinker.""" # test both mode diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 77839f7cd8..e0534518e0 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -313,12 +313,22 @@ async def model_version_async(self) -> int: @property def model_path(self) -> str: - """Get the model path.""" + """ + Returns the path to the model files based on the current engine type. + + - For 'vllm' engine: returns the model path from the configuration (`config.model_path`) + - For 'tinker' engine: returns the path to the most recent sampler weights + """ return ray.get(self.model.get_model_path.remote()) @property async def model_path_async(self) -> str: - """Get the model path.""" + """ + Returns the path to the model files based on the current engine type. + + - For 'vllm' engine: returns the model path from the configuration (`config.model_path`) + - For 'tinker' engine: returns the path to the most recent sampler weights + """ return await self.model.get_model_path.remote() @property @@ -366,9 +376,7 @@ def get_openai_client(self) -> openai.OpenAI: if self.engine_type == "tinker": # ! TODO: because tinker's OpenAI API interface is in beta, # we need to use original API in thinker instead. - ori_create = self.openai_async_client.chat.completions.create - - async def chat_completions(*args, **kwargs): + def chat_completions(*args, **kwargs): messages = kwargs.pop("messages") chat_response = ray.get( self.model.chat.remote( @@ -383,7 +391,7 @@ async def chat_completions(*args, **kwargs): self.history.extend(chat_response) return response - self.openai_async_client.chat.completions.create = chat_completions + self.openai_client.chat.completions.create = chat_completions elif self.enable_history: # add a decorator to the openai client to record history @@ -427,8 +435,6 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: if self.engine_type == "tinker": # ! TODO: because tinker's OpenAI API interface is in beta, # we need to use original API in thinker instead. 
-            ori_create = self.openai_async_client.chat.completions.create
-
             async def chat_completions(*args, **kwargs):
                 messages = kwargs.pop("messages")
                 chat_response = await self.model.chat.remote(
                     messages=messages,
@@ -464,7 +470,6 @@ async def record_chat_completions(*args, **kwargs):
             self.openai_async_client.chat.completions.create = record_chat_completions
 
         # get model_path from the sync openai client to avoid async call here
-        # openai_client = self.get_openai_client()
         setattr(self.openai_async_client, "model_path", self.model_path)
         return self.openai_async_client
 
diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py
index b94986d587..0be19b855f 100644
--- a/trinity/common/models/tinker_model.py
+++ b/trinity/common/models/tinker_model.py
@@ -1,6 +1,6 @@
 import time
 from os import getenv
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import ray
 import tinker
@@ -57,7 +57,7 @@
             topk_prompt_logprobs=kwargs.get("topk_prompt_logprobs", self.config.logprobs),
         )
 
-    async def generate(self, prompt: str, **kwargs) -> List[Experience]:
+    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         """Generate a responses from a prompt in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
@@ -152,7 +152,7 @@
 
         return experiences
 
-    async def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
+    async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
@@ -274,5 +274,5 @@ def get_api_key(self):
         return getenv("TINKER_API_KEY")
 
     def get_model_path(self) -> Optional[str]:
-        """Get the model path"""
+        """Get the latest sampler weight path."""
         return self.model_path