    Behavior of sklearn's OrdinalEncoder

    Notes on how OrdinalEncoder behaves.

    sklearn.preprocessing.OrdinalEncoder - scikit-learn 0.24.2 documentation

    The input to this transformer should be an array-like of integers or strings, denoting the values taken on by categorical (discrete) features. The features are converted to ordinal integers. This results in a single column of integers (0 to n_categories - 1) per feature.

    scikit-learn.org

    It converts categorical features to integers:

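    >>> import numpy as np
    >>> import pandas as pd
    >>> from sklearn.preprocessing import OrdinalEncoder
    >>> df = pd.DataFrame([1,2,3,2,1])
    >>> oe = OrdinalEncoder(categories='auto', dtype=np.int64)
    >>> oe.fit(df)
    >>> oe.transform(df)
    array([[0],
           [1],
           [2],
           [1],
           [0]])
    
    >>> oe.categories_
    [array([1, 2, 3])]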
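The mapping that `categories='auto'` produces can be sketched in plain Python: for each column, the sorted unique values become the categories, and each value is replaced by its index, giving integers from 0 to n_categories - 1. This is a minimal sketch assuming the values in a column are hashable and mutually comparable; `ordinal_encode` is a hypothetical helper for illustration, not part of sklearn, and it skips sklearn's input validation and options.

```python
from typing import Any, List, Sequence, Tuple


def ordinal_encode(
    rows: Sequence[Sequence[Any]],
) -> Tuple[List[List[int]], List[List[Any]]]:
    """Hypothetical helper imitating categories='auto' for small inputs."""
    n_cols = len(rows[0])
    # Categories per column: sorted unique values (like oe.categories_)
    categories = [sorted({row[j] for row in rows}) for j in range(n_cols)]
    # Value -> index lookup per column
    lookups = [{value: idx for idx, value in enumerate(cats)} for cats in categories]
    # Replace every value with the index of its category
    encoded = [[lookups[j][row[j]] for j in range(n_cols)] for row in rows]
    return encoded, categories


if __name__ == "__main__":
    encoded, categories = ordinal_encode([[1], [2], [3], [2], [1]])
    print(encoded)     # [[0], [1], [2], [1], [0]]
    print(categories)  # [[1, 2, 3]]
```

Because the categories are sorted, a column that mixes types such as `1` and `'na'` cannot even be ordered in Python 3 (`sorted` raises `TypeError`), which is one way to see why all categories need a common type.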

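    All categories must have the same type (they are cast to the dtype of the input column).

    If necessary, convert the values beforehand, for example to str: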

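    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import OrdinalEncoder

    df = pd.DataFrame([1,2,3,2,1]).astype('str')  # after astype(str) there is no problem
    categories = [['1', '2', '3', 'na']]
    oe = OrdinalEncoder(categories=categories, dtype=np.int64)
    oe.fit(df)
    oe.transform(df)  # array([[0], [1], [2], [1], [0]])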
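Why the str conversion helps can be reproduced with NumPy alone. According to the traceback for the failing case, sklearn 0.24.2's `_fit` casts user-supplied categories to the column's dtype with `np.array(self.categories[i], dtype=Xi.dtype)`. The sketch below repeats only that cast; `cast_categories` is a hypothetical helper, the value-to-index lookup is a simplification rather than sklearn's actual transform code, and the object dtype for the str column is an assumption based on how pandas stores strings.

```python
import numpy as np


def cast_categories(categories, column_dtype):
    # Hypothetical helper: mirrors np.array(self.categories[i], dtype=Xi.dtype)
    return np.array(categories, dtype=column_dtype)


# pd.DataFrame([1, 2, 3, 2, 1]) stores this column as int64
int_column = np.array([1, 2, 3, 2, 1], dtype=np.int64)

int_cast_error = None
try:
    # 'na' cannot be parsed as an integer, so this raises ValueError
    cast_categories([1, 2, 3, "na"], int_column.dtype)
except ValueError as exc:
    int_cast_error = exc

# After .astype('str'), pandas holds Python str objects with object dtype
str_column = np.array(["1", "2", "3", "2", "1"], dtype=object)
cats = cast_categories(["1", "2", "3", "na"], str_column.dtype)

# Each value becomes the index of its category; the unused 'na' does no harm
lookup = {value: idx for idx, value in enumerate(cats)}
codes = [lookup[value] for value in str_column]

print(int_cast_error)
print(codes)  # [0, 1, 2, 1, 0]
```

With object dtype the cast keeps every category as-is, so an extra string such as `'na'` is simply an unused category.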
    Without that conversion, the integer column cannot be combined with a category list that also contains the string 'na', and fit fails:

    >>> df = pd.DataFrame([1,2,3,2,1])
    >>> categories = [[1,2,3, 'na']]
    >>> oe = OrdinalEncoder(categories=categories, dtype=np.int64)
    >>> oe.fit(df)
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-92-fbe59670fc3d> in <module>()
          2 categories = [[1,2,3, 'na']]
          3 oe = OrdinalEncoder(categories=categories, dtype=np.int64)
    ----> 4 oe.fit(df)
          5 oe.transform(df)
    
    /usr/local/lib/python3.7/dist-packages/sklearn/preprocessing/_encoders.py in _fit(self, X, handle_unknown)
         86                 cats = _encode(Xi)
         87             else:
    ---> 88                 cats = np.array(self.categories[i], dtype=Xi.dtype)
         89                 if Xi.dtype != object:
         90                     if not np.all(np.sort(cats) == cats):
    
    ValueError: invalid literal for int() with base 10: 'na'

    The categories are cast to the column's dtype (int64 here), and 'na' cannot be converted to an integer.

    © 2025 DROBE All rights reserved.