fix(llm): special tokens and leading space (#1831)
This commit is contained in:
parent 08c4ab175e
commit 347be643f7
|
@@ -243,12 +243,19 @@ class SagemakerLLM(CustomLLM):
             event_stream = resp["Body"]
             start_json = b"{"
             stop_token = "<|endoftext|>"
+            first_token = True
 
             for line in LineIterator(event_stream):
                 if line != b"" and start_json in line:
                     data = json.loads(line[line.find(start_json) :].decode("utf-8"))
-                    if data["token"]["text"] != stop_token:
+                    special = data["token"]["special"]
+                    stop = data["token"]["text"] == stop_token
+                    if not special and not stop:
                         delta = data["token"]["text"]
+                        # trim the leading space for the first token if present
+                        if first_token:
+                            delta = delta.lstrip()
+                            first_token = False
                         text += delta
                         yield CompletionResponse(delta=delta, text=text, raw=data)
 
Loading…
Reference in New Issue