
Commit 142520a

fix container env + Neuron related changes

1 parent 56c15d8

5 files changed: +78 −29

Dockerfile-neuron

Lines changed: 3 additions & 4 deletions

```diff
@@ -90,10 +90,9 @@ ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
 ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
 ARG NEURONX_TOOLS_VERSION=2.26.14.0
 
-ARG NEURONX_CC_VERSION=2.21.18209.0+043b1bf7
-ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.13553+1e4dd6ca
+ARG NEURONX_CC_VERSION=2.21.33363.0+82129205
+ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.16998+e9bf8a50
 ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf
-ARG NEURONX_DISTRIBUTED_INFERENCE_VERSION=0.6.10598+a59fdc00
 
 RUN apt-get update \
     && apt-get upgrade -y \
@@ -137,13 +136,13 @@ RUN apt-get update \
     && apt-get clean
 
 ENV PATH="/opt/aws/neuron/bin:${PATH}"
-ENV NEURON_RT_VISIBLE_CORES=ALL
 
 RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \
     --extra-index-url https://pypi.org/simple \
     --trusted-host pip.repos.neuron.amazonaws.com \
     neuronx-cc==$NEURONX_CC_VERSION \
     torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
+    torchvision \
     neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
    && rm -rf ~/.cache/pip/*
```
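After rebuilding the image, it is worth confirming that the pinned packages actually landed in the container. A minimal sanity-check sketch in Python, run inside the built container; the package list simply mirrors the `pip install` above and is not part of the commit:

```python
# Hedged sketch: print installed versions of the Neuron packages pinned in
# Dockerfile-neuron, so they can be compared against the ARG values above.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("neuronx-cc", "torch-neuronx", "torchvision", "neuronx_distributed"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```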

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -14,7 +14,7 @@
 from text_embeddings_server.models.jinaBert_model import FlashJinaBert
 from text_embeddings_server.models.flash_mistral import FlashMistral
 from text_embeddings_server.models.flash_qwen3 import FlashQwen3
-from text_embeddings_server.models.neuron_models import NeuronSentenceTransformers
+from text_embeddings_server.models.neuron_models import NeuronSentenceTransformersModel
 
 from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron
 
@@ -80,7 +80,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     # Neuron cases
     if is_neuron():
         if config.model_type == "bert":
-            return create_model(NeuronSentenceTransformers, model_path)
+            return create_model(NeuronSentenceTransformersModel, model_path, device, datatype)
 
     if (
         hasattr(config, "auto_map")
```

backends/python/server/text_embeddings_server/models/neuron_models.py

Lines changed: 17 additions & 13 deletions

```diff
@@ -3,7 +3,7 @@
 
 from pathlib import Path
 from typing import Type, List
-from optimum.neuron import NeuronModelForSentenceTransformers
+from optimum.neuron import NeuronSentenceTransformers
 from opentelemetry import trace
 
 from text_embeddings_server.models import Model
@@ -12,14 +12,14 @@
 tracer = trace.get_tracer(__name__)
 
 
-class NeuronSentenceTransformers(Model):
+class NeuronSentenceTransformersModel(Model):
     def __init__(
         self,
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
     ):
-        model = NeuronModelForSentenceTransformers.from_pretrained(model_path)
+        model = NeuronSentenceTransformers.from_pretrained(model_path)
 
         self.hidden_size = model.config.hidden_size
         position_offset = 0
@@ -42,7 +42,7 @@ def __init__(
             is not None
         )
 
-        super(NeuronSentenceTransformers, self).__init__(
+        super(NeuronSentenceTransformersModel, self).__init__(
             model=model, dtype=dtype, device=device
         )
 
@@ -52,16 +52,20 @@ def batch_type(self) -> Type[PaddedBatch]:
 
     @tracer.start_as_current_span("embed")
     def embed(self, batch: PaddedBatch) -> List[Embedding]:
-        pass
-
-    @tracer.start_as_current_span("predict")
-    def predict(self, batch: PaddedBatch) -> List[Score]:
         kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
         if self.has_token_type_ids:
             kwargs["token_type_ids"] = batch.token_type_ids
-        if self.has_position_ids:
-            kwargs["position_ids"] = batch.position_ids
+        output = self.model(**kwargs)
+
+        sentence_embedding = output["sentence_embedding"]
 
-        output = self.model(**kwargs, return_dict=True)
-        all_scores = output.logits.tolist()
-        return [Score(values=scores) for scores in all_scores]
+        return [
+            Embedding(
+                values=sentence_embedding[i * self.hidden_size : (i + 1) * self.hidden_size]
+            )
+            for i in range(len(batch))
+        ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass
```
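For reference, `optimum.neuron.NeuronSentenceTransformers` (the class the wrapper now imports) is loaded with `from_pretrained` and returns a dict-style output whose `sentence_embedding` entry holds the pooled vectors, which is exactly what the rewritten `embed()` consumes. A standalone sketch of that flow outside the TEI server; the tokenizer usage and the printed shape are assumptions, not part of this commit:

```python
# Hedged sketch: call a pre-compiled Neuron sentence-transformers model
# directly and read its sentence embedding, mirroring what embed() does.
from transformers import AutoTokenizer
from optimum.neuron import NeuronSentenceTransformers

model_id = "optimum/bge-base-en-v1.5-neuronx"  # repo from the docs below; assumed to ship model.neuron
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = NeuronSentenceTransformers.from_pretrained(model_id)

inputs = tokenizer(["What is Deep Learning?"], padding=True, return_tensors="pt")
output = model(**inputs)
print(output["sentence_embedding"].shape)  # expected: [batch, hidden_size]
```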

backends/src/lib.rs

Lines changed: 53 additions & 7 deletions

```diff
@@ -67,6 +67,15 @@ fn is_hpu() -> bool {
     }
 }
 
+fn is_neuron() -> bool {
+    match Command::new("neuron-ls")
+        .output()
+    {
+        Ok(output) => output.status.success(),
+        Err(_) => false,
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct Backend {
     /// Channel to communicate with the background thread
@@ -409,16 +418,39 @@ async fn init_backend(
     if let Some(api_repo) = api_repo.as_ref() {
         if cfg!(feature = "python") || cfg!(feature = "candle") {
             let start = std::time::Instant::now();
-            if download_safetensors(api_repo).await.is_err() {
-                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
-                tracing::info!("Downloading `pytorch_model.bin`");
-                api_repo
-                    .get("pytorch_model.bin")
+            if is_neuron() {
+                tracing::info!("Downloading `model.neuron`");
+                let model_files = download_neuron(api_repo)
                     .await
                     .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
-            }
 
-            tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+                if model_files.is_empty() {
+                    tracing::error!(
+                        "Neuron model files not found in the repository. \
+                        You can easily compile your model to neuron format following the guide: \
+                        https://huggingface.co/docs/optimum-neuron/en/model_doc/sentence_transformers/overview"
+                    );
+                    return Err(BackendError::WeightsNotFound(
+                        "No Neuron model files found".into(),
+                    ));
+                }
+
+                tracing::info!("Neuron model downloaded in {:?}", start.elapsed());
+            } else {
+                if download_safetensors(api_repo).await.is_err() {
+                    tracing::warn!(
+                        "safetensors weights not found. Using `pytorch_model.bin` instead. \
+                        Model loading will be significantly slower."
+                    );
+                    tracing::info!("Downloading `pytorch_model.bin`");
+                    api_repo
+                        .get("pytorch_model.bin")
+                        .await
+                        .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
+                }
+
+                tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+            }
         }
     }
 
@@ -655,6 +687,20 @@ async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     Ok(model_files)
 }
 
+async fn download_neuron(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
+    let mut model_files: Vec<PathBuf> = Vec::new();
+
+    tracing::info!("Downloading `model.neuron`");
+    match api.get("model.neuron").await {
+        Ok(p) => model_files.push(p),
+        Err(err) => {
+            tracing::warn!("Could not download `model.neuron`: {err}");
+        }
+    };
+
+    Ok(model_files)
+}
+
 #[cfg(feature = "candle")]
 #[derive(Debug, Clone, Deserialize, PartialEq)]
 enum ModuleType {
```
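The new `is_neuron()` helper simply probes for the `neuron-ls` CLI and treats a zero exit status as "Neuron hardware present". When debugging a container, the same probe can be reproduced in a few lines of Python (a sketch of the detection logic, not code from this commit):

```python
# Hedged sketch: replicate the backend's Neuron detection by running
# `neuron-ls` and checking its exit status, as is_neuron() does in lib.rs.
import subprocess

def is_neuron() -> bool:
    try:
        return subprocess.run(["neuron-ls"], capture_output=True).returncode == 0
    except FileNotFoundError:
        return False  # binary missing entirely, so no Neuron runtime

print("Neuron detected:", is_neuron())
```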

docs/source/en/aws_neuron.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -22,16 +22,16 @@ To build a container optimized for AWS Neuron devices, run the following command
 ```shell
 platform="neuron"
 
-docker build . -f Dockerfile-neuron -t tei_neuron
+docker build . -f Dockerfile-neuron -t tei-neuron:main
 ```
 
 ### Deploy Docker Container
 
 To deploy your model on an AWS Trainium or Inferentia instance, use the following command:
 
 ```shell
-model='Qwen/Qwen3-Embedding-0.6B'
+model='optimum/bge-base-en-v1.5-neuronx'
 volume=$PWD/data
 
-docker run -p 8080:80 -v $volume:/data tei_neuron --model-id $model
+docker run -p 8080:80 -v $volume:/data tei-neuron:main --model-id $model --dtype float32
 ```
````
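Once the container is up, the endpoint can be exercised over TEI's standard HTTP API. A minimal sketch using Python's `requests`, assuming the default `/embed` route and the port mapping from the command above:

```python
# Hedged sketch: request an embedding from the container started above.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/embed",
    json={"inputs": "What is Deep Learning?"},
)
resp.raise_for_status()
vectors = resp.json()  # one list of floats per input string
print(len(vectors[0]))
```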
