fastnlp/fastNLP

Importing fastNLP fails for version 1.0.1

davidleejy opened this issue · 2 comments

Describe the bug
Importing fastNLP fails for version 1.0.1.

To Reproduce
Steps to reproduce the behavior:
Python 3.11, $ pip install "FastNLP>=1.0.0alpha"

import fastNLP

Error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 13
---> 13 from fastNLP import DataSet

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/__init__.py:3
      2 from fastNLP.envs import *
----> 3 from fastNLP.core import *
      5 __version__ = '1.0.1'

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/__init__.py:114
      1 __all__ = [
      2     # callbacks
      3     'Callback',
   (...)
    112     'Vocabulary'
    113 ]
--> 114 from .callbacks import *
    115 from .collators import *
    116 from .controllers import *

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/callbacks/__init__.py:41
     39 from .load_best_model_callback import LoadBestModelCallback
     40 from .early_stop_callback import EarlyStopCallback
---> 41 from .torch_callbacks import *
     42 from .more_evaluate_callback import MoreEvaluateCallback
     43 from .has_monitor_callback import ResultsMonitor, HasMonitorCallback

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/callbacks/torch_callbacks/__init__.py:8
      1 __all__ = [
      2     'TorchWarmupCallback',
      3     'TorchGradClipCallback'
      4 ]
      7 from .torch_lr_sched_callback import TorchWarmupCallback
----> 8 from .torch_grad_clip_callback import TorchGradClipCallback

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py:6
      4 from typing import Union, List
      5 from ..callback import Callback
----> 6 from ...drivers.torch_driver.fairscale import FairScaleDriver
      7 from ...drivers.torch_driver import TorchDriver
      8 from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/drivers/__init__.py:30
      1 __all__ = [
      2     'Driver',
      3     'TorchDriver',
   (...)
     27     'optimizer_state_to_device'
     28 ]
---> 30 from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, FairScaleDriver, \
     31     TorchFSDPDriver, torch_seed_everything, optimizer_state_to_device
     32 from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver
     33 from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/drivers/torch_driver/__init__.py:18
     16 from .torch_driver import TorchDriver
     17 from .deepspeed import DeepSpeedDriver
---> 18 from .torch_fsdp import TorchFSDPDriver
     19 from .utils import torch_seed_everything, optimizer_state_to_device

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/fastNLP/core/drivers/torch_driver/torch_fsdp.py:7
      4 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12, _NEED_IMPORT_TORCH
      6 if _TORCH_GREATER_EQUAL_1_12:
----> 7     from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType
      9 if _NEED_IMPORT_TORCH:
     10     import torch

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/fsdp/__init__.py:1
----> 1 from .flat_param import FlatParameter
      2 from .fully_sharded_data_parallel import (
      3     BackwardPrefetch,
      4     CPUOffload,
   (...)
     11     StateDictType,
     12 )
     13 from .wrap import ParamExecOrderWrapPolicy

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/fsdp/flat_param.py:26
     23 import torch.nn.functional as F
     24 from torch import Tensor
---> 26 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
     27 from ._utils import _alloc_storage, _free_storage, _set_fsdp_flattened, p_assert
     29 __all__ = [
     30     "FlatParameter",
     31     "FlatParamHandle",
   (...)
     37     "HandleTrainingState",
     38 ]

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/fsdp/_fsdp_extensions.py:7
      4 import torch
      5 import torch.distributed as dist
----> 7 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
     10 class FSDPExtensions(ABC):
     11     """
     12     This enables some customizable hooks to enable composability with tensor
     13     parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
     14     set a custom :class:`FSDPExtensions` that implements the hooks.
     15     """

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/fsdp/_shard_utils.py:10
      8 import torch.nn.functional as F
      9 from torch.distributed import distributed_c10d
---> 10 from torch.distributed._shard.sharded_tensor import (
     11     Shard,
     12     ShardedTensor,
     13     ShardedTensorMetadata,
     14     TensorProperties,
     15 )
     16 from torch.distributed._shard.sharding_spec import (
     17     ChunkShardingSpec,
     18     EnumerableShardingSpec,
     19     ShardingSpec,
     20     ShardMetadata,
     21 )
     24 def _sharding_spec_to_offsets(
     25     sharding_spec: ShardingSpec, tensor_numel: int, world_size: int
     26 ) -> List[int]:

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/__init__.py:1
----> 1 from .api import (
      2     _replicate_tensor,
      3     _shard_tensor,
      4     load_with_process_group,
      5     shard_module,
      6     shard_parameter,
      7 )

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/api.py:6
      4 import torch.nn as nn
      5 from torch.distributed import distributed_c10d
----> 6 from torch.distributed._shard.sharded_tensor import (
      7     ShardedTensor,
      8     _PartialTensor
      9 )
     10 from .replicated_tensor import ReplicatedTensor
     11 from .sharding_spec import (
     12     ShardingSpec,
     13     ChunkShardingSpec
     14 )

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py:8
      5 from typing import List
      7 import torch
----> 8 import torch.distributed._shard.sharding_spec as shard_spec
      9 from torch.distributed._shard.partial_tensor import _PartialTensor
     11 from .api import (
     12     _CUSTOM_SHARDED_OPS,
     13     _SHARDED_OPS,
   (...)
     18     TensorProperties,
     19 )

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/sharding_spec/__init__.py:1
----> 1 from .api import (
      2     DevicePlacementSpec,
      3     EnumerableShardingSpec,
      4     PlacementSpec,
      5     ShardingSpec,
      6     _infer_sharding_spec_from_shards_metadata,
      7 )
      8 from .chunk_sharding_spec import (
      9     ChunkShardingSpec,
     10 )
     12 from torch.distributed._shard.metadata import ShardMetadata

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/sharding_spec/api.py:16
      8 from ._internals import (
      9     check_tensor,
     10     get_chunked_dim_size,
     11     get_split_size,
     12     validate_non_overlapping_shards_metadata
     13 )
     14 from torch.distributed._shard.metadata import ShardMetadata
---> 16 import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
     17 from torch.distributed._shard.op_registry_utils import _decorator_func
     19 if TYPE_CHECKING:
     20     # Only include ShardedTensor when do type checking, exclude it
     21     # from run-time to resolve circular dependency.

File ~/condaenvs/bbt-hf425/lib/python3.11/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py:70
     61     @staticmethod
     62     def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
     63         return TensorProperties(
     64             dtype=tensor.dtype,
     65             layout=tensor.layout,
   (...)
     68             pin_memory=tensor.is_pinned()
     69         )
---> 70 @dataclass
     71 class ShardedTensorMetadata(object):
     72     """
     73     Represents metadata for :class:`ShardedTensor`
     74     """
     76     # Metadata about each shard of the Tensor

File ~/condaenvs/bbt-hf425/lib/python3.11/dataclasses.py:1221, in dataclass(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot)
   1218     return wrap
   1220 # We're called as @dataclass without parens.
-> 1221 return wrap(cls)

File ~/condaenvs/bbt-hf425/lib/python3.11/dataclasses.py:1211, in dataclass.<locals>.wrap(cls)
   1210 def wrap(cls):
-> 1211     return _process_class(cls, init, repr, eq, order, unsafe_hash,
   1212                           frozen, match_args, kw_only, slots,
   1213                           weakref_slot)

File ~/condaenvs/bbt-hf425/lib/python3.11/dataclasses.py:959, in _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot)
    956         kw_only = True
    957     else:
    958         # Otherwise it's a field of some type.
--> 959         cls_fields.append(_get_field(cls, name, type, kw_only))
    961 for f in cls_fields:
    962     fields[f.name] = f

File ~/condaenvs/bbt-hf425/lib/python3.11/dataclasses.py:816, in _get_field(cls, a_name, a_type, default_kw_only)
    812 # For real fields, disallow mutable defaults.  Use unhashable as a proxy
    813 # indicator for mutability.  Read the __hash__ attribute from the class,
    814 # not the instance.
    815 if f._field_type is _FIELD and f.default.__class__.__hash__ is None:
--> 816     raise ValueError(f'mutable default {type(f.default)} for field '
    817                      f'{f.name} is not allowed: use default_factory')
    819 return f

ValueError: mutable default <class 'torch.distributed._shard.sharded_tensor.metadata.TensorProperties'> for field tensor_properties is not allowed: use default_factory

Thank you for your report! It seems that PyTorch does not yet officially support Python 3.11.

According to this comment, you can either use Python 3.10 instead or install the PyTorch nightly release. I have tried Python 3.10, and fastNLP imports successfully.

Thanks. I am able to verify that Python 3.10 doesn't throw this exception.