Conversation
| - `completion_ids` (`list[list[int]]`): | ||
| List of lists of token IDs representing the model-generated completions for each prompt. | ||
| - `logprobs` (`list[list[list[float]]]`): | ||
| - `logprobs` (`list[list[list[float]]]` | `list[list[float]]`): |
not so sure about this change, I think it's always `list[list[list[float]]]`:
>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-5.021359443664551], [-4.506778717041016], [-0.16933049261569977], [-1.328775405883789], [-5.707622051239014], [-6.522100925445557], [-1.3067556619644165], [-6.344869136810303], [-0.061117831617593765], [-1.44622802734375], [-0.04607903212308884], [-0.00957468245178461], [-3.2259726524353027], [-3.274900436401367], [-3.4954776763916016]], [[-0.0422004759311676], [-5.590917587280273], [-1.9313716888427734], [-1.106265664100647], [-0.0110595328733325], [-0.0010186012368649244], [-1.6689160474925302e-05], [-1.3351351299206726e-05], [-0.11895198374986649], [-0.0006528153317049146], [-1.1920922133867862e-06], [-0.002003330737352371], [-0.003181754844263196], [-0.049776118248701096], [-0.011047743260860443], [-0.8712134957313538]]]
>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=0)["logprobs"]
[[[-0.2079554945230484], [-0.7713592648506165, -1.6463592052459717], [-0.4071059823036194], [-1.109581470489502, -1.609581470489502], [-0.06224556267261505], [-1.5437123775482178, -3.6687123775482178], [-1.0006335973739624, -1.5006335973739624], [-0.1264881044626236], [-0.006634233985096216], [-0.3543497920036316], [-0.8073388934135437], [-0.2931288480758667], [-0.2984274625778198], [-1.352670669555664, -1.977670669555664], [-0.2433444708585739, -1.9933444261550903], [-0.45821458101272583, -3.833214521408081]], [[-0.0422004759311676], [-2.3409173488616943, -6.840917587280273], [-0.7118824124336243], [-2.3060879707336426, -7.743587970733643], [-0.8871606588363647], [-1.029018759727478], [-0.04861651360988617], [-0.0002108589978888631], [-2.753696753643453e-05], [-1.4305012882687151e-05], [-0.05704395845532417], [-0.01012428104877472], [-0.00042632073746062815], [0.0], [-2.622600959512056e-06], [-0.004123874939978123]]]
wait a sec, isn't it a behavioural change from vLLM?
what if you don't pass logprobs at all?
server side error:
>>> client.generate(["Hello, AI!", "Tell me a joke"], logprobs=None)["logprobs"]
ERROR: Exception in ASGI application
Traceback (most recent call last):
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/protocols/http/httptools_impl.py", line 416, in run_asgi
result = await app( # type: ignore[func-returns-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
self.scope, self.receive, self.send
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
return await self.app(scope, receive, send)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/applications.py", line 1135, in __call__
await super().__call__(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/applications.py", line 107, in __call__
await self.middleware_stack(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 186, in __call__
raise exc
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/errors.py", line 164, in __call__
await self.app(scope, receive, _send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/middleware/exceptions.py", line 63, in __call__
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
await self.app(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 716, in __call__
await self.middleware_stack(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 736, in app
await route.handle(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/routing.py", line 290, in handle
await self.app(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 115, in app
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 101, in app
response = await f(request)
^^^^^^^^^^^^^^^^
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 377, in app
content = await serialize_response(
^^^^^^^^^^^^^^^^^^^^^^^^^
...<10 lines>...
)
^
File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/fastapi/routing.py", line 215, in serialize_response
raise ResponseValidationError(
...<3 lines>...
)
fastapi.exceptions.ResponseValidationError: 2 validation errors:
{'type': 'list_type', 'loc': ('response', 'logprobs'), 'msg': 'Input should be a valid list', 'input': None}
{'type': 'list_type', 'loc': ('response', 'logprob_token_ids'), 'msg': 'Input should be a valid list', 'input': None}
File "/fsx/qgallouedec/trl/trl/scripts/vllm_serve.py", line 512, in generate
POST /generate/
or
>>> client.generate(["Hello, AI!", "Tell me a joke"])["logprobs"]
[[[-0.2079554945230484], [-1.6463592052459717], [-0.4071059823036194], [-1.609581470489502], [-0.06224556267261505], [-1.5437123775482178], [-2.3698196411132812], [-0.29048386216163635], [-0.2877020239830017], [-2.652569532394409], [-0.4416201710700989], [-10.628130912780762], [-1.1008961200714111], [-0.37943634390830994], [-0.7756030559539795], [-14.819398880004883]], [[-0.0422004759311676], [-5.340917587280273], [-0.9934616088867188], [-1.7789040803909302], [-0.009760985150933266], [-0.00201974855735898], [-1.4781842764932662e-05], [-5.8412379075889476e-06], [-0.07062072306871414], [-0.0012309125158935785], [0.0], [-0.018214812502264977], [-0.001459129503928125], [-0.5814141035079956], [-0.01416344940662384], [-0.0002592465898487717]]]
I agree, I think it's always `list[list[list[float]]]`. Or actually, it should be `list[list[list[float]]] | None`, and `vllm_serve` should correctly return `None` when `logprobs=None` is passed; right now it crashes, as Quentin noted.
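One way to get the `None` behaviour suggested above, sketched in plain Python (the function and field names here are illustrative, not the actual `vllm_serve` definitions): only attach `logprobs` to the response when it was actually computed, so a schema with an optional field validates instead of raising `ResponseValidationError`.

```python
# Hypothetical sketch of the server-side fix: omit logprobs from the payload
# when the client did not request them, rather than serializing None into a
# field typed as a required list.
def build_response(completion_ids, logprobs=None):
    response = {"completion_ids": completion_ids}
    # Only include logprobs when the sampler produced them; the response
    # model's logprobs field would then be Optional and default to None.
    if logprobs is not None:
        response["logprobs"] = logprobs
    return response
```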
| ) | ||
| # vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob | ||
| logprobs = [[lp[0] for lp in seq] for seq in logprobs] | ||
| if isinstance(logprobs[0][0], list): |
I don't "really" like this check since it depends on a lot of nesting. Open to a better way to handle this guard.
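If the response shape really is always `list[list[list[float]]]` (as the discussion above suggests), the `isinstance` guard could be dropped entirely and the top-1 reduction applied unconditionally. A minimal sketch under that assumption (function name is illustrative):

```python
# Hedged alternative to the isinstance nesting check: assume the server always
# returns per-token top-k lists, and take index 0 (the sampled token's logprob)
# from each one.
def top1_logprobs(logprobs: list[list[list[float]]]) -> list[list[float]]:
    # Each innermost list holds the top-k logprobs for one generated token;
    # the first entry corresponds to the sampled token.
    return [[token_topk[0] for token_topk in seq] for seq in logprobs]
```

This only works if the `logprobs=None` case is handled upstream (e.g. the server returns `None` and the caller skips the reduction), which is the other fix discussed in this thread.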
What does this PR do?
#5107 introduced a bug in `GRPOTrainer`: it sets `logprobs` to 0 but then reduces over the returned logprobs array, even though the nesting of vLLM's response depends on that parameter.
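To illustrate the shape dependence (toy values, not real model output): with `logprobs=0` each token still yields a one-element top-k list, so indexing `lp[0]` is valid; the reduction only breaks if that innermost level is absent.

```python
# Toy example of the nesting the reduction assumes: one sequence, two tokens,
# each token carrying a one-element top-k list of logprobs.
nested = [[[-0.21], [-5.02]]]
# The trainer's reduction keeps the top-1 (sampled token) logprob per token.
flat = [[lp[0] for lp in seq] for seq in nested]
# flat == [[-0.21, -5.02]]
```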
Fixes # (issue)
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.