# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/15 13:18
@Usage  : run as a script; expects a.npy and b.npy in the working directory
@Desc   : Scratch comparison of torch_dct.dct applied to the whole tensor
          versus applied to each a[:, i, :] slice and re-stacked on dim 1.
'''
import numpy as np
import torch
import torch_dct as dct


def main():
    """Load a.npy / b.npy, compute the DCT two ways, and print everything."""
    a = np.load("a.npy")
    b = np.load("b.npy")  # only printed below; not part of the DCT comparison
    a = torch.tensor(a)

    # DCT of the whole tensor in a single call.
    c = dct.dct(x=a)

    # The same transform applied slice by slice over dim 1, then stacked back
    # on dim 1 so the result lines up with `c` for comparison.
    # NOTE(review): a[:, i, :] assumes a.npy holds an array with >= 3 dims.
    d = torch.stack([dct.dct(a[:, i, :]) for i in range(a.shape[1])], dim=1)

    print("a.shape:", a.shape)
    print("a:", a)
    print("b:", b)
    print("c:", c)
    print("d:", d)
    # Make the outcome of the comparison explicit instead of eyeballing prints.
    print("c == d:", torch.allclose(c, d))


if __name__ == "__main__":
    main()