Properly support batched/non-batched with vllm/llama.cpp, and other streamlining #77

Conversation
src/instructlab/sdg/llmblock.py
Outdated
parsed_outputs = self._parse(output)
# pylint: disable=consider-using-generator
max_length = max([len(value) for value in parsed_outputs.values()])
GitHub lint is suggesting to use max(len(value) for value in parsed_outputs.values())
I think this logic still needs to be cleaned up anyway; it's not doing what it was intended to do.
return super()._validate(prompt_template, input_dict)

def server_supports_batched(client, model_id: str) -> bool:
This might be a nitpick, but could we use server_supports_batching instead of server_supports_batched?
I think batched is better, since it refers to the inputs. Even without batched inputs, the server might do batching internally.
for prompt in prompts:
    for _ in range(n):
        response = self.client.completions.create(
            prompt=prompt, **generate_args
        )
We could rewrite this as
responses = [
self.client.completions.create(prompt=prompt, **generate_args)
for prompt in prompts
for _ in range(n)
]
wdyt?
Yes, but then we would require an additional loop anyway.
Yeah
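For reference, a minimal sketch of how the comprehension and that additional loop could fit together; self.client, generate_args, prompts and n are taken from the snippets quoted above, and the final unpacking step is an assumption about how the results get collected, not the PR's actual code.

# Sketch only: sequential (non-batched) generation, one request per prompt per n.
responses = [
    self.client.completions.create(prompt=prompt, **generate_args)
    for prompt in prompts
    for _ in range(n)
]
# The "additional loop" mentioned above: each response still has to be
# unpacked into plain strings before the block can use the outputs.
results = [
    choice.text.strip()
    for response in responses
    for choice in response.choices
]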
Thanks @npalaska, I addressed most of those comments.
markmc left a comment
Cool, looks like a great direction
At least resolve the "#TODO remove sample from samples" thing
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
self.server_supports_batched = server_supports_batched(client, model_id)
The FlowParams in #64 would give us a place to do this once rather than for every LLMBlock, but that can be fixed up later
See PipelineContext in #86 now
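As a loose illustration of that idea, a shared context object could run the probe once and every LLMBlock could just read the cached flag. This assumes a PipelineContext shape for the sake of the sketch, not the actual definition in #86.

# Sketch only: hypothetical shared context that probes the server once,
# so individual LLMBlocks don't each call server_supports_batched().
from dataclasses import dataclass, field

@dataclass
class PipelineContext:
    client: object   # OpenAI-compatible client
    model_id: str
    batched: bool = field(init=False)

    def __post_init__(self):
        # server_supports_batched is the helper added in this PR
        self.batched = server_supports_batched(self.client, self.model_id)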
| "model": self.model, | ||
| "temperature": 0, | ||
| "max_tokens": 12000, | ||
| #"seed": 12345, TBD |
Delete? Or add an explanation to the comment
)
return [choice.text.strip() for choice in response.choices]

n = gen_kwargs.get("n", 1)
I would have imagined doing this in reverse - including "num_instructions_to_generate" in the block config and adding 'n' to gen_kwargs if batching was supported. No biggie though
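A rough sketch of that reversed flow, purely as an illustration; num_instructions_to_generate and block_config are assumed names here, not the PR's actual config keys.

# Sketch only: read the desired count from the block config and only pass "n"
# through to the server when batching is supported; otherwise loop.
num_samples = block_config.get("num_instructions_to_generate", 1)  # hypothetical key
if self.server_supports_batched:
    gen_kwargs["n"] = num_samples
    outputs = self._generate(samples, **gen_kwargs)
else:
    outputs = []
    for _ in range(num_samples):
        outputs.extend(self._generate(samples, **gen_kwargs))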
if not self._validate(self.prompt_template, sample):
    return None
logger.warning("Sample failed validation") #TODO add details
#TODO remove sample from samples
Hmm, should this be in a separate PR? If it's in this PR, the TODO should be resolved.
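One way the TODO could be resolved is to filter failing samples out up front rather than returning None mid-generation; a sketch under that assumption, not what the PR does.

# Sketch only: keep only the samples that pass prompt validation, log the rest.
valid_samples = []
for sample in samples:
    if self._validate(self.prompt_template, sample):
        valid_samples.append(sample)
    else:
        logger.warning("Sample failed validation, skipping: %s", sample)
samples = valid_samples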
outputs = self._generate(samples, **gen_kwargs)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = gen_kwargs.get("n", 1)
Hmm, and here's a reason to make num_parallel_samples part of the block config ... and add 'n' to gen_kwargs based on that
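For context, the step this feeds into is regrouping the flat output list so that each input sample gets its n parallel generations back; a sketch of that regrouping, with the grouping logic assumed rather than quoted from the PR.

# Sketch only: split the flat outputs list into chunks of num_parallel_samples,
# one chunk per input sample.
num_parallel_samples = gen_kwargs.get("n", 1)
grouped_outputs = [
    outputs[i : i + num_parallel_samples]
    for i in range(0, len(outputs), num_parallel_samples)
]
# Each input sample should now map to exactly one group of n outputs.
assert len(grouped_outputs) == len(samples)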
    supported = len(response.choices) == 6
except openai.InternalServerError:
    supported = False
client.server_supports_batched = supported
Hmm, I understand that you want to cache this ... but I don't like setting a new attribute on a class we don't own
I guess this could be removed with a move to FlowParams
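As a stopgap before a FlowParams/PipelineContext move, one alternative is to keep the cache in a module-level dict keyed by the client instead of writing a new attribute onto the OpenAI client object. A sketch of that; the caching strategy and the test prompt are assumptions, not what the PR merged.

import openai

# Hypothetical cache, keyed by client identity, so each server is probed once.
_BATCHED_SUPPORT = {}

def server_supports_batched(client, model_id: str) -> bool:
    if id(client) in _BATCHED_SUPPORT:
        return _BATCHED_SUPPORT[id(client)]
    try:
        # Ask for 6 completions of a trivial prompt; servers without batched
        # support (e.g. llama.cpp) fail or return fewer choices.
        response = client.completions.create(
            model=model_id, prompt="test", n=6, max_tokens=1
        )
        supported = len(response.choices) == 6
    except openai.InternalServerError:
        supported = False
    _BATCHED_SUPPORT[id(client)] = supported
    return supported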
Also, please squash those fixup commits as per instructlab/dev-docs#110.
This pull request has merge conflicts that must be resolved before it can be merged.
As per instructlab/sdg#77 Signed-off-by: Mark McLoughlin <[email protected]>
Closing in favor of #105.
This is based on @npalaska's PR #58.
With these changes we auto-detect whether the server supports batched inputs and, if it does not, send the prompts sequentially.
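Roughly, the generation path this adds looks like the sketch below; method and attribute names follow the snippets quoted in the review above, while the prompt-formatting helper and the defaults dict are assumptions for illustration.

# Sketch only: batched servers get one request with a list of prompts;
# otherwise we fall back to one request per prompt (per n).
def _generate(self, samples, **gen_kwargs):
    prompts = [self._format_prompt(sample) for sample in samples]  # hypothetical helper
    generate_args = {**self.defaults, **gen_kwargs}                # hypothetical defaults

    if self.server_supports_batched:
        response = self.client.completions.create(prompt=prompts, **generate_args)
        return [choice.text.strip() for choice in response.choices]

    n = generate_args.pop("n", 1)
    results = []
    for prompt in prompts:
        for _ in range(n):
            response = self.client.completions.create(prompt=prompt, **generate_args)
            results.append(response.choices[0].text.strip())
    return results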