diff --git a/private_gpt/components/llm/custom/sagemaker.py b/private_gpt/components/llm/custom/sagemaker.py index 7c46111..e20f539 100644 --- a/private_gpt/components/llm/custom/sagemaker.py +++ b/private_gpt/components/llm/custom/sagemaker.py @@ -243,12 +243,19 @@ class SagemakerLLM(CustomLLM): event_stream = resp["Body"] start_json = b"{" stop_token = "<|endoftext|>" + first_token = True for line in LineIterator(event_stream): if line != b"" and start_json in line: data = json.loads(line[line.find(start_json) :].decode("utf-8")) - if data["token"]["text"] != stop_token: + special = data["token"]["special"] + stop = data["token"]["text"] == stop_token + if not special and not stop: delta = data["token"]["text"] + # trim the leading space for the first token if present + if first_token: + delta = delta.lstrip() + first_token = False text += delta yield CompletionResponse(delta=delta, text=text, raw=data)