Description
Feature request
Add support in TRL for async reward functions so users can run batched external API calls (e.g., OpenAI/DeepSeek or local inference) with asyncio.gather when computing rewards in the GRPO trainer.
Motivation
My reward functions for the GRPO trainer make API calls to DeepSeek/OpenAI (in fact this would be useful for any API calls, even for local servers/inference if your node setup supports it).
Whenever I call a reward function over a batch, I need to do so asynchronously so the requests run concurrently instead of one at a time.
One example of how I want to use it:
import asyncio
from openai import AsyncOpenAI

async def check_content(prompts, completions, **kwargs) -> list[float]:
    # Index 1 selects the user message; the system message is not needed here
    prompts = [prompt[1]['content'] for prompt in prompts]
    client = AsyncOpenAI(api_key=deepseek_api_key, base_url="https://api.deepseek.com")

    async def call_model(prompt: str, completion: str) -> float:
        """Call the DeepSeek API to get a relevance score for one prompt-completion pair."""
        response = await client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a judge ...."},
                {"role": "user", "content": f"This is the prompt: {prompt}"},
                {"role": "user", "content": f"This is the answer: {completion}"},
            ],
            stream=False,
        )
        ...  # parse a numeric relevance score from the response

    # Launch one request per prompt-completion pair and score them concurrently
    tasks = []
    for prompt, completion in zip(prompts, completions):
        tasks.append(call_model(prompt, completion[0]['content']))
    scores = await asyncio.gather(*tasks)
    return scores
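As a point of reference, this is roughly the workaround I have to use today, since the trainer only accepts plain callables: a synchronous bridge that runs the coroutine to completion. This is only a sketch assuming the check_content function above; the wrapper name is illustrative.

import asyncio

def check_content_sync(prompts, completions, **kwargs) -> list[float]:
    # Bridge wrapper: drives the async reward function with asyncio.run so it
    # can be passed to GRPOTrainer, at the cost of blocking the training loop
    # until all API calls in the batch complete.
    return asyncio.run(check_content(prompts, completions, **kwargs))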
Your contribution
I can open a PR; I have already fixed this locally, if people are interested.
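For illustration, a minimal sketch of what the trainer-side change could look like: detect whether the user-supplied reward function is a coroutine function and await it, otherwise keep the existing synchronous path. The function name compute_rewards and its signature are hypothetical, not TRL's actual internals.

import asyncio
import inspect

def compute_rewards(reward_func, prompts, completions, **kwargs):
    # If the user supplied an async reward function, run it on the event loop
    # so its internal asyncio.gather can batch the API calls concurrently.
    if inspect.iscoroutinefunction(reward_func):
        return asyncio.run(reward_func(prompts, completions, **kwargs))
    # Otherwise fall back to the current synchronous call path.
    return reward_func(prompts, completions, **kwargs)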