segment-any-text/wtpsplit

SaT is slow


Hi, I tried using SaT as a drop-in replacement for WtP (wtp-canine-s-1l-no-adapters).

However, no matter which variant I try (1l or 3l), inference always takes nearly a second, compared with 0.013 s (CPU) and 0.005 s (GPU) with WtP. For me, SaT's runtime is the same on CPU and GPU.
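When comparing numbers like these, it can help to discard a warm-up call, because the first call may include one-off setup costs. The sketch below is a generic timing helper, not part of wtpsplit. The names `time_calls` and `summarize` are hypothetical, and a plain `str.split` stand-in replaces the real model so the snippet runs without downloading anything.

```python
import statistics
import time
from typing import Callable, List


def time_calls(fn: Callable[[str], object], text: str,
               repeats: int = 7, warmup: int = 1) -> List[float]:
    """Return wall-clock seconds for `repeats` calls of fn(text).

    Warm-up calls run first and are discarded, so one-off costs
    (lazy initialisation, first-call device setup) do not skew the result.
    """
    for _ in range(warmup):
        fn(text)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(text)
        timings.append(time.perf_counter() - start)
    return timings


def summarize(timings: List[float]) -> str:
    # The median is less sensitive than the mean to occasional slow runs.
    return f"median {statistics.median(timings) * 1000:.1f} ms over {len(timings)} runs"


def naive_split(text: str) -> List[str]:
    # Hypothetical stand-in for a real splitter such as sat.split.
    return text.split(". ")


if __name__ == "__main__":
    print(summarize(time_calls(naive_split, "This is a test. " * 100)))
```

With wtpsplit installed, the same helper can be pointed at the bound methods used later in this thread, e.g. `time_calls(sat.split, text)` or `time_calls(wtp.split, text)`.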

System: Docker
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
Python 3.10.12

Installed pip packages:

adapters==0.2.1
aiohttp==3.9.5
aioice==0.9.0
aiortc==1.9.0
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.4.0
asgiref==3.8.1
async-timeout==4.0.3
attrs==23.2.0
av==12.1.0
backoff==2.2.1
bcrypt==4.1.3
beautifulsoup4==4.12.3
build==1.2.1
cached-property==1.5.2
cachetools==5.3.3
certifi==2024.6.2
cffi==1.16.0
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.5.3
click==8.1.7
coloredlogs==15.0.1
cryptography==42.0.8
dataclasses-json==0.6.7
Deprecated==1.2.14
dirtyjson==1.0.8
distro==1.9.0
dnspython==2.6.1
docopt==0.6.2
email_validator==2.2.0
exceptiongroup==1.2.1
fastapi==0.111.0
fastapi-cli==0.0.4
filelock==3.15.4
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.6.1
google-auth==2.30.0
google-crc32c==1.5.0
googleapis-common-protos==1.63.2
greenlet==3.0.3
grpcio==1.64.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.4
humanfriendly==10.0
idna==3.7
ifaddr==0.2.0
ijson==3.3.0
importlib_metadata==7.1.0
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
kubernetes==30.1.0
litellm==1.40.29
llama-cloud==0.0.6
llama-index-core==0.10.50.post1
llama-index-embeddings-huggingface==0.2.1
llama-index-llms-litellm==0.1.4
llama-index-llms-openai==0.1.22
llama-index-readers-file==0.1.25
llama-index-readers-smart-pdf-loader==0.1.4
llama-index-vector-stores-chroma==0.1.8
llmsherpa==0.1.4
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.21.3
mdurl==0.1.2
minijinja==2.0.1
mmh3==4.1.0
monotonic==1.6
mosestokenizer==1.2.1
mpmath==1.3.0
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
onnxruntime==1.18.1
openai==1.35.7
opencv-contrib-python==4.10.0.82
openfile==0.0.7
opentelemetry-api==1.25.0
opentelemetry-exporter-otlp-proto-common==1.25.0
opentelemetry-exporter-otlp-proto-grpc==1.25.0
opentelemetry-instrumentation==0.46b0
opentelemetry-instrumentation-asgi==0.46b0
opentelemetry-instrumentation-fastapi==0.46b0
opentelemetry-proto==1.25.0
opentelemetry-sdk==1.25.0
opentelemetry-semantic-conventions==0.46b0
opentelemetry-util-http==0.46b0
orjson==3.10.5
overrides==7.7.0
packaging==24.1
pandas==2.2.2
pillow==10.3.0
posthog==3.5.0
protobuf==4.25.3
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.7.4
pydantic_core==2.18.4
pyee==11.1.0
Pygments==2.18.0
pylibsrtp==0.10.0
pyOpenSSL==24.1.0
pypdf==4.2.0
PyPika==0.48.9
pyproject_hooks==1.1.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.7.1
rsa==4.9
safetensors==0.4.3
scikit-learn==1.5.0
scipy==1.14.0
sentence-transformers==2.7.0
shellingham==1.5.4
six==1.16.0
skops==0.9.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.31
starlette==0.37.2
striprtf==0.0.26
sympy==1.12.1
tabulate==0.9.0
tenacity==8.4.2
threadpoolctl==3.5.0
tiktoken==0.7.0
tokenizers==0.15.2
tomli==2.0.1
toolwrapper==2.1.0
torch==2.3.1
tqdm==4.66.4
transformers==4.39.3
triton==2.3.1
typer==0.12.3
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.1
uctools==1.3.0
ujson==5.10.0
urllib3==2.2.2
uvicorn==0.30.1
uvloop==0.19.0
watchfiles==0.22.0
websocket-client==1.8.0
websockets==12.0
wrapt==1.16.0
wtpsplit==2.0.3
yarl==1.9.4
zipp==3.19.2

Hi, thanks for catching this! There was a small issue with the tokenizer. We fixed it in wtpsplit==2.0.4; please upgrade.
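To confirm that an environment already has the fixed release, the installed version can be read with the standard library's `importlib.metadata`. This is a minimal sketch; the helper names are hypothetical, and the simple tuple comparison assumes plain `X.Y.Z` release strings (pre-release tags such as `rc1` would need `packaging.version.Version` instead).

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = "2.0.4"  # release named in the maintainer's reply


def parse_release(v: str) -> tuple:
    """Turn a plain 'X.Y.Z' release string into a comparable tuple of ints."""
    return tuple(int(part) for part in v.split("."))


def needs_upgrade(installed: str, fixed_in: str = FIXED_IN) -> bool:
    # Tuple comparison orders numerically, so 2.0.10 sorts after 2.0.4.
    return parse_release(installed) < parse_release(fixed_in)


def check_wtpsplit() -> str:
    """Report whether the installed wtpsplit predates the tokenizer fix."""
    try:
        installed = version("wtpsplit")
    except PackageNotFoundError:
        return "wtpsplit is not installed"
    if needs_upgrade(installed):
        return (f"wtpsplit {installed} predates the fix; "
                f"run: pip install -U 'wtpsplit>={FIXED_IN}'")
    return f"wtpsplit {installed} includes the fix"


if __name__ == "__main__":
    print(check_wtpsplit())
```

For the environment listed above (wtpsplit==2.0.3), `needs_upgrade("2.0.3")` is `True`.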

With the fix, I get the following timings (both models are 1L variants):

SaT:

%timeit sat.split(SENTENCE * 100)

801 ms ± 351 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

SaT + GPU:

%timeit sat.split(SENTENCE * 100)

65.5 ms ± 1.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

WtP:

%timeit wtp.split(SENTENCE * 100)

6.08 s ± 1.49 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

WtP + GPU:

%timeit wtp.split(SENTENCE * 100)

370 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
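The relative speedups implied by these measurements can be computed directly from the reported means. This is only a ratio of means: the CPU runs have large standard deviations (351 ms for SaT, 1.49 s for WtP), so the figures are rough, and they say nothing about very short inputs. The dictionary layout and the `speedup` helper are illustrative, not part of wtpsplit.

```python
# Mean wall-clock times (ms) for split(SENTENCE * 100), as reported above.
reported_ms = {
    ("SaT", "CPU"): 801.0,
    ("SaT", "GPU"): 65.5,
    ("WtP", "CPU"): 6080.0,
    ("WtP", "GPU"): 370.0,
}


def speedup(slow: tuple, fast: tuple) -> float:
    """Ratio of mean runtimes: how many times faster `fast` is than `slow`."""
    return reported_ms[slow] / reported_ms[fast]


if __name__ == "__main__":
    for device in ("CPU", "GPU"):
        ratio = speedup(("WtP", device), ("SaT", device))
        print(f"SaT vs WtP on {device}: {ratio:.1f}x")
    for model in ("SaT", "WtP"):
        ratio = speedup((model, "CPU"), (model, "GPU"))
        print(f"{model} GPU vs CPU: {ratio:.1f}x")
```

Run as a script, this prints roughly 7.6x and 5.6x for SaT over WtP (CPU, GPU), and 12.2x and 16.4x for GPU over CPU (SaT, WtP).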

(Note: for very short sequences and small models, WtP may still be slightly faster. Regardless, you should not use WtP on short sequences: others have reported problematic inconsistencies, and our paper also shows its poor performance there.)

Thank you very much! It now runs much faster than WtP did, even for very short texts, and the sentence segmentation is much more natural.