
fix logprobs handling#5198

Open
winglian wants to merge 1 commit into huggingface:main from winglian:vllm-generate-logprobs

Conversation

@winglian (Contributor)

What does this PR do?

#5107 introduced a bug in GRPOTrainer: it sets `logprobs` to 0 but then attempts to reduce over the logprobs array, even though the nesting of vLLM's response depends on that parameter.
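A minimal sketch of the failure mode (illustrative values only, not the actual TRL code): with `logprobs=0` the server returns one single-element list per token, and the top-1 reduction assumes that nesting; applied to an already-flat `list[list[float]]` response, the same reduction raises.

```python
# With logprobs=0, vLLM returns one per-token list of top-k logprobs
# (here k=1), i.e. list[list[list[float]]]:
nested = [[[-0.21], [-5.02], [-4.51]]]

# The reduction keeps only the first (sampled-token) logprob per token:
reduced = [[lp[0] for lp in seq] for seq in nested]
assert reduced == [[-0.21, -5.02, -4.51]]

# If the response were already flat (list[list[float]]), the same
# reduction fails, since floats are not subscriptable:
flat = [[-0.21, -5.02, -4.51]]
failed = False
try:
    [[lp[0] for lp in seq] for seq in flat]
except TypeError:
    failed = True  # 'float' object is not subscriptable
assert failed
```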

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each prompt.
- `logprobs` (`list[list[list[float]]]`):
- `logprobs` (`list[list[list[float]]]` | `list[list[float]]`):
Member

not so sure about this change, I think it's always `list[list[list[float]]]`:

>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-5.021359443664551], [-4.506778717041016], [-0.16933049261569977], [-1.328775405883789], [-5.707622051239014], [-6.522100925445557], [-1.3067556619644165], [-6.344869136810303], [-0.061117831617593765], [-1.44622802734375], [-0.04607903212308884], [-0.00957468245178461], [-3.2259726524353027], [-3.274900436401367], [-3.4954776763916016]], [[-0.0422004759311676], [-5.590917587280273], [-1.9313716888427734], [-1.106265664100647], [-0.0110595328733325], [-0.0010186012368649244], [-1.6689160474925302e-05], [-1.3351351299206726e-05], [-0.11895198374986649], [-0.0006528153317049146], [-1.1920922133867862e-06], [-0.002003330737352371], [-0.003181754844263196], [-0.049776118248701096], [-0.011047743260860443], [-0.8712134957313538]]]
>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-0.7713592648506165, -1.6463592052459717], [-0.4071059823036194], [-1.109581470489502, -1.609581470489502], [-0.06224556267261505], [-1.5437123775482178, -3.6687123775482178], [-1.0006335973739624, -1.5006335973739624], [-0.1264881044626236], [-0.006634233985096216], [-0.3543497920036316], [-0.8073388934135437], [-0.2931288480758667], [-0.2984274625778198], [-1.352670669555664, -1.977670669555664], [-0.2433444708585739, -1.9933444261550903], [-0.45821458101272583, -3.833214521408081]], [[-0.0422004759311676], [-2.3409173488616943, -6.840917587280273], [-0.7118824124336243], [-2.3060879707336426, -7.743587970733643], [-0.8871606588363647], [-1.029018759727478], [-0.04861651360988617], [-0.0002108589978888631], [-2.753696753643453e-05], [-1.4305012882687151e-05], [-0.05704395845532417], [-0.01012428104877472], [-0.00042632073746062815], [0.0], [-2.622600959512056e-06], [-0.004123874939978123]]]

Member

wait a sec, isn't it a behavioural change from vLLM?

Contributor Author

what if you don't pass logprobs at all?

@qgallouedec (Member), Feb 27, 2026

server side error:

>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=None)["logprobs"]
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/protocols/http/httptools_impl.py", line 416, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self.scope, self.receive, self.send
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/applications.py", line 1135, in __call__
    await super().__call__(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/applications.py", line 107, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/exceptions.py", line 63, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
    await self.app(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 716, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 736, in app
    await route.handle(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 290, in handle
    await self.app(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 115, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 101, in app
    response = await f(request)
               ^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 377, in app
    content = await serialize_response(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<10 lines>...
    )
    ^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 215, in serialize_response
    raise ResponseValidationError(
    ...<3 lines>...
    )
fastapi.exceptions.ResponseValidationError: 2 validation errors:
  {'type': 'list_type', 'loc': ('response', 'logprobs'), 'msg': 'Input should be a valid list', 'input': None}
  {'type': 'list_type', 'loc': ('response', 'logprob_token_ids'), 'msg': 'Input should be a valid list', 'input': None}

  File "/fsx/qgallouedec/trl/trl/scripts/vllm_serve.py", line 512, in generate
    POST /generate/

Member

or

>>> client.generate(["Hello, AI!", "Tell me a joke"])["logprobs"]
[[[-0.2079554945230484], [-1.6463592052459717], [-0.4071059823036194], [-1.609581470489502], [-0.06224556267261505], [-1.5437123775482178], [-2.3698196411132812], [-0.29048386216163635], [-0.2877020239830017], [-2.652569532394409], [-0.4416201710700989], [-10.628130912780762], [-1.1008961200714111], [-0.37943634390830994], [-0.7756030559539795], [-14.819398880004883]], [[-0.0422004759311676], [-5.340917587280273], [-0.9934616088867188], [-1.7789040803909302], [-0.009760985150933266], [-0.00201974855735898], [-1.4781842764932662e-05], [-5.8412379075889476e-06], [-0.07062072306871414], [-0.0012309125158935785], [0.0], [-0.018214812502264977], [-0.001459129503928125], [-0.5814141035079956], [-0.01416344940662384], [-0.0002592465898487717]]]

@LeonEricsson (Collaborator), Feb 27, 2026

I agree, I think it's always `list[list[list[float]]]`. Or actually, it should be `list[list[list[float]]] | None`, and vllm_serve should correctly return `None` when passing `logprobs=None`; right now it crashes, as Quentin noted.
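One way the server could tolerate `logprobs=None` without the ResponseValidationError above (a sketch, not the actual `trl/scripts/vllm_serve.py` code; the field names are taken from the validation errors, the rest is assumed):

```python
from dataclasses import dataclass
from typing import Optional

# Sketch of a response shape where the logprobs fields are optional,
# so logprobs=None round-trips instead of failing response validation.
@dataclass
class GenerateResponse:
    completion_ids: list[list[int]]
    logprobs: Optional[list[list[list[float]]]] = None
    logprob_token_ids: Optional[list[list[list[int]]]] = None

resp = GenerateResponse(completion_ids=[[1, 2, 3]])
assert resp.logprobs is None  # no server-side validation error
```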

)
# vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
if isinstance(logprobs[0][0], list):
    logprobs = [[lp[0] for lp in seq] for seq in logprobs]
Contributor Author

I don't "really" like this check since it depends on a lot of nesting. Open to a better way to handle this guard.
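One less nesting-dependent alternative would be to branch on whether top-k logprobs were requested rather than probing the response shape (a sketch; the function and parameter names are hypothetical, not from the PR):

```python
# Hypothetical helper: decide based on the request parameter, not on
# isinstance-probing the nested response structure.
def extract_sampled_logprobs(logprobs, requested_top_k):
    if requested_top_k is None:
        return logprobs  # already list[list[float]]
    # per-token top-k lists: index 0 is the sampled token's logprob
    return [[tok[0] for tok in seq] for seq in logprobs]

assert extract_sampled_logprobs([[-0.5, -1.0]], None) == [[-0.5, -1.0]]
assert extract_sampled_logprobs([[[-0.5], [-1.0, -2.0]]], 0) == [[-0.5, -1.0]]
```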

3 participants