[Python]회귀(3D plot)

반응형
    반응형

    이어서 회귀를 하겠습니다. 이전 포스팅에서는 메커니즘을 설명하기 위해 이미 알려진 데이터로 했는데요.

    이번에는 데이터가 3차원인 경우에 대해 회귀를 진행해보겠습니다. 

    행렬을 이용해 구성할 수 있습니다. 

     

    실제로 회귀는 아는 데이터가 아니라 잘 모르는 데이터의 모양을 보고 진행을 합니다. 

    하지만 메커니즘을 위해 데이터의 모형을 다 안다는 전제하에 연습용으로 그려내겠습니다.

     

    셋팅

    먼저 데이터 값인 f(z)를 정의하겠습니다. 

    그리고 x, y 를 meshgrid를 이용해 격자로 펼치겠습니다.

    def f(z):
        x, y =z
        return np.sin(x)+0.25*x+np.sqrt(y)
        
    x = np.linspace(0,10,20)
    y = np.linspace(0,10,20)
    X, Y = np.meshgrid(x,y)
    
    Z = f((X,Y))
    x= X.flatten()
    y = Y.flatten()

     

    이걸 그림으로 나타나면 다음과 같이 나옵니다.

    from mpl_toolkits.mplot3d import Axes3D
    fig = plt.figure(figsize=(10,6))
    ax = fig.gca(projection='3d')
    surf = ax.plot_surface(X,Y,Z,rstride=2, cstride=2, cmap='coolwarm',linewidth=0.5,antialiased =True)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('f(x,y)')
    fig.colorbar(surf,shrink=0.5,aspect=5)

     

    3D 회귀

    3차원이라고 확 달라지는 건 없습니다. 이전 포스팅에서 기저함수를 넣었던 것처럼 여기서도 넣어보겠습니다.

    기저를 정했으면 matrix @ a = f(x,y) 가 되는 a를 np.linalg.lstsq()로 찾아낸 후 다시 값을 찾아 그림을 비교해보겠습니다.

    matrix = np.zeros((len(x),5))
    matrix[:,4] = np.sqrt(y)
    matrix[:,3] = np.sin(x)
    matrix[:,2] = y
    matrix[:,1] = x
    matrix[:,0] = 1
    
    reg = np.linalg.lstsq(matrix, f((x,y)),rcond=None)[0]
    RZ = np.dot(matrix,reg).reshape((20,20))
    
    fig = plt.figure(figsize = (10,6))
    ax = fig.gca(projection='3d')
    
    surf1 = ax.plot_surface(X, Y, Z, rstride=2, cstride=2, alpha=0.7, cmap=mpl.cm.coolwarm, linewidth=0.3, antialiased=True)
    surf2 = ax.plot_wireframe(X, Y, RZ, rstride=2, cstride=2,color='blue', label='regression')
    
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('f(x,y)')
    ax.legend()
    fig.colorbar(surf,shrink=0.5,aspect=5)

     

    원래 함수를 그냥 두면 regression 선이 잘 안 보여서 투명도(alpha)를 약간 주었습니다. 

    보시다시피 regression 선을 보면 그림에 맞게 어느정도 그려내는 걸 볼 수 있습니다. 

    회귀가 나름 잘 이루어졌습니다. 

     

     

    관련 포스팅

    [데이터 사이언스/머신러닝 딮러닝] - [Python] 회귀(Regression)

    [Python/Numpy] - [Numpy]격자 그리드 만들기(meshgrid)

    [Python/그래프 그리기] - [matplotlib] 3D plot

     

    댓글

    Designed by JB FACTORY

    ....