neuron-distributed for inference
sonic182 opened this issue · 1 comment
sonic182 commented
Hi, I'm trying to make a CLIP model compatible with neuronx-distributed (because I'm going to continue with a multimodal model after it).
Currently, in my notebook on an inf2.xlarge (Ubuntu 22), I have:
import torch  # needed for torch.tensor below
from torch import nn

from neuronx_distributed.parallel_layers import parallel_state, utils
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from transformers.models.clip import CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPTextTransformer,
    CLIPVisionTransformer,
    CLIPVisionEmbeddings,
    CLIPTextEmbeddings,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPAttention,
    CLIPMLP,
)
from transformers.models.clip.configuration_clip import CLIPTextConfig, CLIPVisionConfig


# DONE
class CLIPAttentionNeuron(CLIPAttention):
    def __init__(self, config):
        super().__init__(config)
        world_size = parallel_state.get_tensor_model_parallel_size()
        print(f"world size: {world_size}")
        # Swap the dense projections for tensor-parallel layers
        self.k_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim)
        self.v_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim)
        self.q_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim)
        self.out_proj = RowParallelLinear(self.embed_dim, self.embed_dim)
        # Each rank keeps its share of the attention heads
        self.num_heads = utils.divide(config.num_attention_heads, world_size)


# DONE
class CLIPMLPNeuron(CLIPMLP):
    def __init__(self, config):
        super().__init__(config)
        self.fc1 = ColumnParallelLinear(config.hidden_size, config.intermediate_size)
        self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size)


# Done
class CLIPEncoderLayerNeuron(CLIPEncoderLayer):
    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttentionNeuron(config)
        self.mlp = CLIPMLPNeuron(config)


# Done
class CLIPEncoderNeuron(CLIPEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([CLIPEncoderLayerNeuron(config) for _ in range(config.num_hidden_layers)])


# Done
class CLIPVisionEmbeddingsNeuron(CLIPVisionEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.position_embedding = ParallelEmbedding(self.num_positions, self.embed_dim)


# Done
class CLIPVisionTransformerNeuron(CLIPVisionTransformer):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = CLIPVisionEmbeddingsNeuron(config)
        self.encoder = CLIPEncoderNeuron(config)


# Done
class CLIPTextEmbeddingsNeuron(CLIPTextEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        embed_dim = config.hidden_size
        self.token_embedding = ParallelEmbedding(config.vocab_size, embed_dim)
        self.position_embedding = ParallelEmbedding(config.max_position_embeddings, embed_dim)


class CLIPTextTransformerNeuron(CLIPTextTransformer):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = CLIPTextEmbeddingsNeuron(config)
        self.encoder = CLIPEncoderNeuron(config)


class CLIPModelNeuron(CLIPModel):
    def __init__(self, config):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformerNeuron(text_config)
        self.vision_model = CLIPVisionTransformerNeuron(vision_config)

        self.visual_projection = ColumnParallelLinear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = ColumnParallelLinear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()
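(Aside: the code above uses the layers' default flags, which gather outputs back on every rank. In the usual Megatron-style pairing, the intermediate activations stay sharded between the column- and row-parallel halves. A minimal sketch, not the code above; `ShardedMLP` is a hypothetical name, and the `gather_output` / `input_is_parallel` flags mirror Megatron-LM's:)

from torch import nn
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)

class ShardedMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # Column-parallel: each rank holds a slice of fc1's output columns;
        # gather_output=False leaves the activations sharded across ranks.
        self.fc1 = ColumnParallelLinear(hidden_size, intermediate_size, gather_output=False)
        self.act = nn.GELU()
        # Row-parallel: input_is_parallel=True consumes the sharded
        # activations and all-reduces the partial sums on the way out.
        self.fc2 = RowParallelLinear(intermediate_size, hidden_size, input_is_parallel=True)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))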
Then, when I try to load the pretrained CLIP weights with:
model = CLIPModelNeuron.from_pretrained("openai/clip-vit-base-patch32")
model
I'm getting this error:
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 1
----> 1 model = CLIPModelNeuron.from_pretrained("openai/clip-vit-base-patch32")
      2 model

File ~/notebooks/venv/lib/python3.10/site-packages/transformers/modeling_utils.py:3626, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3620 config = cls._autoset_attn_implementation(
   3621     config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
   3622 )
   3624 with ContextManagers(init_contexts):
   3625     # Let's make sure we don't run the init function of buffer modules
-> 3626     model = cls(config, *model_args, **model_kwargs)
   3628 # make sure we use the model's config since the __init__ call might have copied it
   3629 config = model.config

Cell In[1], line 107
    104 self.text_embed_dim = text_config.hidden_size
    105 self.vision_embed_dim = vision_config.hidden_size
--> 107 self.text_model = CLIPTextTransformerNeuron(text_config)
    108 self.vision_model = CLIPVisionTransformerNeuron(vision_config)
    110 self.visual_projection = ColumnParallelLinear(self.vision_embed_dim, self.projection_dim, bias=False)

Cell In[1], line 81
     79 def __init__(self, config):
     80     super().__init__(config)
---> 81 self.embeddings = CLIPTextEmbeddingsNeuron(config)
     82 self.encoder = CLIPEncoderNeuron(config)

Cell In[1], line 75
     73 super().__init__(config)
     74 embed_dim = config.hidden_size
---> 75 self.token_embedding = ParallelEmbedding(config.vocab_size, embed_dim)
     76 self.position_embedding = ParallelEmbedding(config.max_position_embeddings, embed_dim)

File ~/notebooks/venv/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:129, in ParallelEmbedding.__init__(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, init_method, device, dtype)
    127 self.scale_grad_by_freq = scale_grad_by_freq
    128 self.sparse = sparse
--> 129 self.tensor_model_parallel_size = get_tensor_model_parallel_size()
    130 # Divide the weight matrix along the vocabulary dimension.
    131 (
    132     self.start_index,
    133     self.end_index,
    (...)
    137     self.tensor_model_parallel_size,
    138 )

File ~/notebooks/venv/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/parallel_state.py:188, in get_tensor_model_parallel_size()
    186 if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
    187     return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
--> 188 return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())

File ~/notebooks/venv/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/parallel_state.py:173, in get_tensor_model_parallel_group(as_list)
    171 def get_tensor_model_parallel_group(as_list=False):
    172     """Get the tensor model parallel group the caller rank belongs to."""
--> 173     assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
    174     return _TENSOR_MODEL_PARALLEL_GROUP._mesh if as_list else _TENSOR_MODEL_PARALLEL_GROUP

AssertionError: intra_layer_model parallel group is not initialized
The thing is... do I need a distributed process for inference?
And if so, how can I start one on an inf2 or trn* instance? I'm a bit of a newbie with torch.distributed.
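For anyone hitting the same assertion: the parallel layers query the tensor-parallel group at construction time, so torch.distributed has to be initialized (on the XLA backend) and the model-parallel state set up before the model is built. A minimal sketch of that setup, assuming a torchrun launch; the script name and TP degree here are illustrative, not taken from the issue:

# init_tp.py -- illustrative setup, launched with e.g.:
#   torchrun --nproc_per_node=2 init_tp.py
import torch
import torch_xla.distributed.xla_backend  # registers the "xla" backend
from neuronx_distributed.parallel_layers import parallel_state

# Bring up the process group on the XLA backend (Neuron devices).
torch.distributed.init_process_group("xla")

# Set up the tensor-model-parallel group; this is exactly what the
# failing assertion in parallel_state.py checks for.
parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)

print("tp size:", parallel_state.get_tensor_model_parallel_size())

With that in place, CLIPModelNeuron.from_pretrained(...) would have to run after initialize_model_parallel, since ParallelEmbedding asks for the group inside its constructor.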
sonic182 commented
closed in favor of aws-neuron/neuronx-distributed#23