Async Support for Reward Functions #4130

@mohamadlakkis

Feature request

Add support in TRL for async reward functions, so users can run batched external API calls (e.g. OpenAI/DeepSeek or local inference) concurrently with asyncio.gather when computing rewards in the GRPO trainer.

Motivation

My reward functions for the GRPO trainer make API calls to DeepSeek/OpenAI. (This would be useful for any API call, including local servers or local inference, if your node setup supports it.)

Whenever I call a reward function over a batch, I need to do so asynchronously, so that the requests for the whole batch run concurrently.

One example of how I want to use it:

import asyncio
from openai import AsyncOpenAI

async def check_content(prompts, completions, **kwargs) -> list[float]:
    # Index 1 is the user message; the system message at index 0 is not needed.
    prompts = [prompt[1]['content'] for prompt in prompts]
    # AsyncOpenAI (rather than OpenAI) is needed so that `create` can be awaited.
    client = AsyncOpenAI(api_key=deepseek_api_key, base_url="https://api.deepseek.com")

    async def call_model(prompt: str, completion: str) -> int:
        """Call the DeepSeek API to get a relevance score for one prompt-completion pair."""
        response = await client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a judge ...."},
                {"role": "user", "content": f"This is the prompt: {prompt}"},
                {"role": "user", "content": f"This is the answer: {completion}"},
            ],
            stream=False
        )
        ....

    # Build one coroutine per prompt-completion pair and run them concurrently.
    tasks = []
    for prompt, completion in zip(prompts, completions):
        tasks.append(call_model(prompt, completion[0]['content']))
    scores = await asyncio.gather(*tasks)
    return scores
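
For illustration, here is a minimal sketch of how the calling side could accept both sync and async reward functions. This is not TRL's actual code: `call_reward_func` is a hypothetical helper, the two toy reward functions stand in for real ones, and the sketch assumes no event loop is already running in the calling thread (otherwise `asyncio.run` raises a `RuntimeError`).

```python
import asyncio
import inspect


def call_reward_func(reward_func, prompts, completions, **kwargs):
    """Hypothetical helper: call a reward function, sync or async."""
    if inspect.iscoroutinefunction(reward_func):
        # Async reward function: drive the coroutine to completion.
        # Assumes no event loop is already running in this thread.
        return asyncio.run(reward_func(prompts, completions, **kwargs))
    # Plain reward function: call it directly, as today.
    return reward_func(prompts, completions, **kwargs)


async def async_length_reward(prompts, completions, **kwargs):
    """Toy async reward: one concurrent 'request' per completion."""

    async def score(completion):
        await asyncio.sleep(0)  # stands in for an awaited API call
        return float(len(completion))

    return await asyncio.gather(*(score(c) for c in completions))


def sync_length_reward(prompts, completions, **kwargs):
    """Toy sync reward computing the same score without any awaiting."""
    return [float(len(c)) for c in completions]


rewards_async = call_reward_func(async_length_reward, ["p1", "p2"], ["ab", "abcd"])
rewards_sync = call_reward_func(sync_length_reward, ["p1", "p2"], ["ab", "abcd"])
```

With this kind of dispatch, existing synchronous reward functions keep working unchanged, and an `async def` reward function such as `check_content` above would simply be awaited once per batch.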

Your contribution

I can open a PR if people are interested; I already have this working locally.
