-
Notifications
You must be signed in to change notification settings - Fork 0
/
oauth.go
142 lines (129 loc) · 3.77 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package oauth
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/joho/godotenv"
	"golang.org/x/oauth2"
)
// ctx is the package-wide background context passed to the oauth2 token
// exchange and to the oauth2-backed HTTP client.
var ctx = context.Background()
// Agent bundles the HTTP client and OAuth tokens used to talk to the Schwab API.
// Values are produced by Initiate; refresh replaces the stored tokens.
type Agent struct {
	// client is the HTTP client for API calls (intended to be the
	// oauth2.Config.Client built in Initiate).
	client *http.Client
	// tokens holds the current bearer/refresh tokens and their expirations.
	tokens Token
}
// Token is the pair of OAuth tokens for the Schwab API together with the
// times at which each is considered expired. readDB decodes
// ~/.trade/bar.json into this type; with no json tags, encoding/json matches
// object keys to these field names case-insensitively.
type Token struct {
	// RefreshExpiration is when the refresh token is treated as expired
	// (parseAccessTokenResponse sets it to 168 hours after parsing).
	RefreshExpiration time.Time
	// Refresh is the refresh token, exchanged for a new bearer token by refresh.
	Refresh string
	// BearerExpiration is when the bearer token is treated as expired
	// (parseAccessTokenResponse sets it to 30 minutes after parsing).
	BearerExpiration time.Time
	// Bearer is the access (bearer) token.
	Bearer string
}
func init() {
err := godotenv.Load()
check(err)
}
// homeDir returns the current user's home directory as reported by
// os.UserHomeDir. If the directory cannot be determined, check terminates
// the process, so callers always receive a usable path.
func homeDir() (dir string) {
	var err error
	if dir, err = os.UserHomeDir(); err != nil {
		check(err)
	}
	return dir
}
// getStringInBetween returns the text between the first occurrence of start
// in str and the first occurrence of end that follows it.
//
// It returns "" when start does not occur in str, or when end does not occur
// anywhere after that first start. Occurrences of end that appear before
// start are ignored.
func getStringInBetween(str string, start string, end string) (result string) {
	_, after, found := strings.Cut(str, start)
	if !found {
		return ""
	}
	// Search for end only in the text following start, so an earlier end
	// delimiter cannot produce an inverted (invalid) slice.
	result, _, found = strings.Cut(after, end)
	if !found {
		return ""
	}
	return result
}
// check is the package's blanket error policy: a nil err is ignored, and a
// non-nil err is logged with an "[ERR]" prefix and terminates the process.
// More granular error handling is planned for after v0.9.0.
func check(err error) {
	if err == nil {
		return
	}
	log.Fatalf("[ERR] %s", err.Error())
}
// parseAccessTokenResponse extracts the bearer and refresh tokens from a token
// endpoint response body.
//
// s is expected to be a JSON object with "access_token" and "refresh_token"
// string fields. Expirations are not read from the response: the refresh
// token is given 168 hours and the bearer token 30 minutes from now.
//
// If s is not valid JSON the error is logged and a Token with empty Bearer
// and Refresh is returned; a missing field likewise leaves that token empty.
func parseAccessTokenResponse(s string) Token {
	now := time.Now()
	token := Token{
		RefreshExpiration: now.Add(time.Hour * 168),
		BearerExpiration:  now.Add(time.Minute * 30),
	}

	// Decode properly instead of splitting on ',' and ':', which broke on
	// values containing those characters and on keys adjacent to braces.
	var body struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
	}
	if err := json.Unmarshal([]byte(s), &body); err != nil {
		log.Printf("[ERR] parsing token response: %s", err)
		return token
	}

	token.Bearer = body.AccessToken
	token.Refresh = body.RefreshToken
	return token
}
// readDB reads the saved tokens from ~/.trade/bar.json and decodes them into
// a Token. The process exits (via check) if the file cannot be read or does
// not contain valid JSON.
func readDB() Token {
	var tokens Token
	// filepath.Join builds the path with the OS-specific separator.
	body, err := os.ReadFile(filepath.Join(homeDir(), ".trade", "bar.json"))
	check(err)
	check(json.Unmarshal(body, &tokens))
	return tokens
}
// Initiate returns an Agent authenticated against the Schwab API.
//
// If ~/.trade/bar.json does not exist, the interactive OAuth authorization
// code flow (with PKCE) is run:
//   - the authorization URL is printed;
//   - the user pastes the resulting authorization code on stdin;
//   - the code is exchanged for tokens;
//   - the tokens are saved to ~/.trade/bar.json (directory 0700, file 0600).
//
// Otherwise the saved tokens are loaded with readDB. If no bearer token is
// present, ~/.trade is removed and the process exits so the next run starts
// over.
//
// Configuration comes from the APPKEY, SECRET and CBURL (redirect URL)
// environment variables. Any failure terminates the process.
func Initiate() *Agent {
	dir := filepath.Join(homeDir(), ".trade")
	dbPath := filepath.Join(dir, "bar.json")

	conf := &oauth2.Config{
		ClientID:     os.Getenv("APPKEY"),
		ClientSecret: os.Getenv("SECRET"),
		RedirectURL:  os.Getenv("CBURL"),
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://api.schwabapi.com/v1/oauth/authorize",
			TokenURL: "https://api.schwabapi.com/v1/oauth/token",
		},
	}

	// Saved tokens exist: reuse them instead of prompting the user again.
	if _, err := os.Stat(dbPath); !errors.Is(err, os.ErrNotExist) {
		tokens := readDB()
		if tokens.Bearer == "" {
			check(os.RemoveAll(dir))
			log.Fatalf("[err] please reinitiate, something went wrong\n")
		}
		tok := &oauth2.Token{
			AccessToken:  tokens.Bearer,
			TokenType:    "Bearer",
			RefreshToken: tokens.Refresh,
			Expiry:       tokens.BearerExpiration,
		}
		return &Agent{client: conf.Client(ctx, tok), tokens: tokens}
	}

	// First run: interactive authorization. The state value is not verified
	// because the code is pasted by hand rather than received by a server.
	verifier := oauth2.GenerateVerifier()
	authURL := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
	fmt.Printf("Open the following URL, authorize access, then paste the authorization code:\n%s\n", authURL)

	var code string
	if _, err := fmt.Scan(&code); err != nil {
		log.Fatal(err)
	}
	tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
	if err != nil {
		log.Fatal(err)
	}

	now := time.Now()
	bearerExpiry := tok.Expiry
	if bearerExpiry.IsZero() {
		// Same default lifetime parseAccessTokenResponse assumes.
		bearerExpiry = now.Add(time.Minute * 30)
	}
	tokens := Token{
		RefreshExpiration: now.Add(time.Hour * 168),
		Refresh:           tok.RefreshToken,
		BearerExpiration:  bearerExpiry,
		Bearer:            tok.AccessToken,
	}

	// Persist the tokens so later runs take the readDB path above.
	check(os.MkdirAll(dir, 0o700))
	data, err := json.MarshalIndent(tokens, "", "  ")
	check(err)
	check(os.WriteFile(dbPath, data, 0o600))

	return &Agent{client: conf.Client(ctx, tok), tokens: tokens}
}
// refresh exchanges the saved refresh token for a new bearer token at the
// Schwab token endpoint and stores the result in agent.tokens.
//
// The refresh token is read from ~/.trade/bar.json via readDB. The client is
// authenticated with HTTP Basic auth built from the APPKEY and SECRET
// environment variables. A transport error or a non-200 response terminates
// the process, in line with check. If the response carries no refresh token,
// the previous one is kept.
func (agent *Agent) refresh() {
	oldTokens := readDB()

	credentials := fmt.Sprintf("%s:%s", os.Getenv("APPKEY"), os.Getenv("SECRET"))
	authStringRefresh := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(credentials)))

	// The body is application/x-www-form-urlencoded, so values must be
	// escaped; interpolating the raw token would corrupt any reserved chars.
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {oldTokens.Refresh},
	}
	req, err := http.NewRequest(http.MethodPost, "https://api.schwabapi.com/v1/oauth/token", bytes.NewBufferString(form.Encode()))
	check(err)
	req.Header = http.Header{
		"Authorization": {authStringRefresh},
		"Content-Type":  {"application/x-www-form-urlencoded"},
	}

	client := &http.Client{Timeout: 30 * time.Second}
	res, err := client.Do(req)
	check(err)
	defer res.Body.Close()

	bodyBytes, err := io.ReadAll(res.Body)
	check(err)
	// Without this check an error response would be parsed into empty
	// tokens and silently overwrite the current ones.
	if res.StatusCode != http.StatusOK {
		log.Fatalf("[ERR] token refresh failed: %s: %s", res.Status, bodyBytes)
	}

	tokens := parseAccessTokenResponse(string(bodyBytes))
	if tokens.Refresh == "" {
		tokens.Refresh = oldTokens.Refresh
	}
	agent.tokens = tokens
}