Hello, I am a new learner. Could you please upload the .py file for running the model on the test set? That would be very helpful. Thanks a lot!
Thank you for the question. You can refer to the following code.
```python
import os
import pickle
import shutil

import numpy as np
import torch

from sa_convlstm import SAConvLSTM


def main():
    # Load the configuration saved during training
    os.makedirs('result', exist_ok=True)
    with open('config_train.pkl', 'rb') as config_file:
        configs = pickle.load(config_file)

    # Build the model and restore the trained weights on the configured device
    model = SAConvLSTM(configs.input_dim, configs.hidden_dim, configs.d_attn,
                       configs.kernel_size).to(configs.device)
    checkpoint = torch.load('checkpoint.chk', map_location=configs.device)
    model.load_state_dict(checkpoint['net'])
    model.eval()

    test_dir = 'tcdata/enso_final_test_data_B/'
    for file in os.listdir(test_dir):
        test_x = np.load(os.path.join(test_dir, file))  # (12, lat, lon, 4); channels in order: sst, t300, ua, va
        test_x = test_x[:, :, range(19, 67), 0]  # (12, 24, 48): keep only sst and the selected longitudes
        test_x = test_x[None]  # (1, 12, 24, 48): add a batch dimension
        # Add a channel dimension and move the input to the same device as the model
        inputs = torch.from_numpy(test_x).float()[:, :, None].to(configs.device)  # (1, 12, 1, 24, 48)
        with torch.no_grad():
            _, pred = model(inputs, train=False)  # (1, 24)
        np.save(os.path.join('result', file), pred[0].cpu().numpy())  # (24,)

    # Pack the contents of the result/ directory into result.zip
    shutil.make_archive('result', 'zip', 'result')


if __name__ == '__main__':
    main()
```
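If you want to check what the preprocessing in the script above does to the array shapes before loading the model, here is a minimal NumPy-only sketch. `preprocess` is a hypothetical helper name, and the 24 x 72 grid in the usage lines is an assumption for illustration only; the slice `19:67` picks the same 48 longitude columns as `range(19, 67)`.

```python
import numpy as np


def preprocess(raw: np.ndarray) -> np.ndarray:
    """Turn one raw test sample into a (batch, time, channel, lat, lon) array.

    Assumes `raw` has shape (12, lat, lon, 4) with channels ordered
    sst, t300, ua, va, as in the script above.
    """
    # Keep only the SST channel (index 0) and longitude columns 19..66 (48 columns)
    sst = raw[:, :, 19:67, 0]  # (12, lat, 48)
    # Add a leading batch axis and a channel axis right after the time axis
    return sst[None, :, None].astype(np.float32)  # (1, 12, 1, lat, 48)


if __name__ == "__main__":
    raw = np.random.rand(12, 24, 72, 4)  # hypothetical sample on an assumed 24 x 72 grid
    print(preprocess(raw).shape)  # (1, 12, 1, 24, 48)
```

The resulting layout matches what `torch.from_numpy(test_x).float()[:, :, None]` produces in the script, so a shape mismatch here is a quick hint that the input files differ from what the model expects.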
Thanks bro!