autocert_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package autocert
  5. import (
  6. "context"
  7. "crypto"
  8. "crypto/ecdsa"
  9. "crypto/elliptic"
  10. "crypto/rand"
  11. "crypto/rsa"
  12. "crypto/tls"
  13. "crypto/x509"
  14. "crypto/x509/pkix"
  15. "encoding/base64"
  16. "encoding/json"
  17. "fmt"
  18. "html/template"
  19. "io"
  20. "math/big"
  21. "net/http"
  22. "net/http/httptest"
  23. "reflect"
  24. "sync"
  25. "testing"
  26. "time"
  27. "golang.org/x/crypto/acme"
  28. )
  29. var discoTmpl = template.Must(template.New("disco").Parse(`{
  30. "new-reg": "{{.}}/new-reg",
  31. "new-authz": "{{.}}/new-authz",
  32. "new-cert": "{{.}}/new-cert"
  33. }`))
  34. var authzTmpl = template.Must(template.New("authz").Parse(`{
  35. "status": "pending",
  36. "challenges": [
  37. {
  38. "uri": "{{.}}/challenge/1",
  39. "type": "tls-sni-01",
  40. "token": "token-01"
  41. },
  42. {
  43. "uri": "{{.}}/challenge/2",
  44. "type": "tls-sni-02",
  45. "token": "token-02"
  46. }
  47. ]
  48. }`))
  49. type memCache struct {
  50. mu sync.Mutex
  51. keyData map[string][]byte
  52. }
  53. func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
  54. m.mu.Lock()
  55. defer m.mu.Unlock()
  56. v, ok := m.keyData[key]
  57. if !ok {
  58. return nil, ErrCacheMiss
  59. }
  60. return v, nil
  61. }
  62. func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
  63. m.mu.Lock()
  64. defer m.mu.Unlock()
  65. m.keyData[key] = data
  66. return nil
  67. }
  68. func (m *memCache) Delete(ctx context.Context, key string) error {
  69. m.mu.Lock()
  70. defer m.mu.Unlock()
  71. delete(m.keyData, key)
  72. return nil
  73. }
  74. func newMemCache() *memCache {
  75. return &memCache{
  76. keyData: make(map[string][]byte),
  77. }
  78. }
  79. func dummyCert(pub interface{}, san ...string) ([]byte, error) {
  80. return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
  81. }
  82. func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
  83. // use EC key to run faster on 386
  84. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  85. if err != nil {
  86. return nil, err
  87. }
  88. t := &x509.Certificate{
  89. SerialNumber: big.NewInt(1),
  90. NotBefore: start,
  91. NotAfter: end,
  92. BasicConstraintsValid: true,
  93. KeyUsage: x509.KeyUsageKeyEncipherment,
  94. DNSNames: san,
  95. }
  96. if pub == nil {
  97. pub = &key.PublicKey
  98. }
  99. return x509.CreateCertificate(rand.Reader, t, t, pub, key)
  100. }
  101. func decodePayload(v interface{}, r io.Reader) error {
  102. var req struct{ Payload string }
  103. if err := json.NewDecoder(r).Decode(&req); err != nil {
  104. return err
  105. }
  106. payload, err := base64.RawURLEncoding.DecodeString(req.Payload)
  107. if err != nil {
  108. return err
  109. }
  110. return json.Unmarshal(payload, v)
  111. }
  112. func TestGetCertificate(t *testing.T) {
  113. man := &Manager{Prompt: AcceptTOS}
  114. defer man.stopRenew()
  115. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  116. testGetCertificate(t, man, "example.org", hello)
  117. }
  118. func TestGetCertificate_trailingDot(t *testing.T) {
  119. man := &Manager{Prompt: AcceptTOS}
  120. defer man.stopRenew()
  121. hello := &tls.ClientHelloInfo{ServerName: "example.org."}
  122. testGetCertificate(t, man, "example.org", hello)
  123. }
  124. func TestGetCertificate_ForceRSA(t *testing.T) {
  125. man := &Manager{
  126. Prompt: AcceptTOS,
  127. Cache: newMemCache(),
  128. ForceRSA: true,
  129. }
  130. defer man.stopRenew()
  131. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  132. testGetCertificate(t, man, "example.org", hello)
  133. cert, err := man.cacheGet(context.Background(), "example.org")
  134. if err != nil {
  135. t.Fatalf("man.cacheGet: %v", err)
  136. }
  137. if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
  138. t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
  139. }
  140. }
  141. func TestGetCertificate_nilPrompt(t *testing.T) {
  142. man := &Manager{}
  143. defer man.stopRenew()
  144. url, finish := startACMEServerStub(t, man, "example.org")
  145. defer finish()
  146. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  147. if err != nil {
  148. t.Fatal(err)
  149. }
  150. man.Client = &acme.Client{
  151. Key: key,
  152. DirectoryURL: url,
  153. }
  154. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  155. if _, err := man.GetCertificate(hello); err == nil {
  156. t.Error("got certificate for example.org; wanted error")
  157. }
  158. }
  159. func TestGetCertificate_expiredCache(t *testing.T) {
  160. // Make an expired cert and cache it.
  161. pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  162. if err != nil {
  163. t.Fatal(err)
  164. }
  165. tmpl := &x509.Certificate{
  166. SerialNumber: big.NewInt(1),
  167. Subject: pkix.Name{CommonName: "example.org"},
  168. NotAfter: time.Now(),
  169. }
  170. pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
  171. if err != nil {
  172. t.Fatal(err)
  173. }
  174. tlscert := &tls.Certificate{
  175. Certificate: [][]byte{pub},
  176. PrivateKey: pk,
  177. }
  178. man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
  179. defer man.stopRenew()
  180. if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
  181. t.Fatalf("man.cachePut: %v", err)
  182. }
  183. // The expired cached cert should trigger a new cert issuance
  184. // and return without an error.
  185. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  186. testGetCertificate(t, man, "example.org", hello)
  187. }
  188. func TestGetCertificate_failedAttempt(t *testing.T) {
  189. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  190. w.WriteHeader(http.StatusBadRequest)
  191. }))
  192. defer ts.Close()
  193. const example = "example.org"
  194. d := createCertRetryAfter
  195. f := testDidRemoveState
  196. defer func() {
  197. createCertRetryAfter = d
  198. testDidRemoveState = f
  199. }()
  200. createCertRetryAfter = 0
  201. done := make(chan struct{})
  202. testDidRemoveState = func(domain string) {
  203. if domain != example {
  204. t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
  205. }
  206. close(done)
  207. }
  208. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  209. if err != nil {
  210. t.Fatal(err)
  211. }
  212. man := &Manager{
  213. Prompt: AcceptTOS,
  214. Client: &acme.Client{
  215. Key: key,
  216. DirectoryURL: ts.URL,
  217. },
  218. }
  219. defer man.stopRenew()
  220. hello := &tls.ClientHelloInfo{ServerName: example}
  221. if _, err := man.GetCertificate(hello); err == nil {
  222. t.Error("GetCertificate: err is nil")
  223. }
  224. select {
  225. case <-time.After(5 * time.Second):
  226. t.Errorf("took too long to remove the %q state", example)
  227. case <-done:
  228. man.stateMu.Lock()
  229. defer man.stateMu.Unlock()
  230. if v, exist := man.state[example]; exist {
  231. t.Errorf("state exists for %q: %+v", example, v)
  232. }
  233. }
  234. }
  235. // startACMEServerStub runs an ACME server
  236. // The domain argument is the expected domain name of a certificate request.
  237. func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
  238. // echo token-02 | shasum -a 256
  239. // then divide result in 2 parts separated by dot
  240. tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
  241. verifyTokenCert := func() {
  242. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  243. _, err := man.GetCertificate(hello)
  244. if err != nil {
  245. t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
  246. return
  247. }
  248. }
  249. // ACME CA server stub
  250. var ca *httptest.Server
  251. ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  252. w.Header().Set("replay-nonce", "nonce")
  253. if r.Method == "HEAD" {
  254. // a nonce request
  255. return
  256. }
  257. switch r.URL.Path {
  258. // discovery
  259. case "/":
  260. if err := discoTmpl.Execute(w, ca.URL); err != nil {
  261. t.Errorf("discoTmpl: %v", err)
  262. }
  263. // client key registration
  264. case "/new-reg":
  265. w.Write([]byte("{}"))
  266. // domain authorization
  267. case "/new-authz":
  268. w.Header().Set("location", ca.URL+"/authz/1")
  269. w.WriteHeader(http.StatusCreated)
  270. if err := authzTmpl.Execute(w, ca.URL); err != nil {
  271. t.Errorf("authzTmpl: %v", err)
  272. }
  273. // accept tls-sni-02 challenge
  274. case "/challenge/2":
  275. verifyTokenCert()
  276. w.Write([]byte("{}"))
  277. // authorization status
  278. case "/authz/1":
  279. w.Write([]byte(`{"status": "valid"}`))
  280. // cert request
  281. case "/new-cert":
  282. var req struct {
  283. CSR string `json:"csr"`
  284. }
  285. decodePayload(&req, r.Body)
  286. b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
  287. csr, err := x509.ParseCertificateRequest(b)
  288. if err != nil {
  289. t.Errorf("new-cert: CSR: %v", err)
  290. }
  291. if csr.Subject.CommonName != domain {
  292. t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
  293. }
  294. der, err := dummyCert(csr.PublicKey, domain)
  295. if err != nil {
  296. t.Errorf("new-cert: dummyCert: %v", err)
  297. }
  298. chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
  299. w.Header().Set("link", chainUp)
  300. w.WriteHeader(http.StatusCreated)
  301. w.Write(der)
  302. // CA chain cert
  303. case "/ca-cert":
  304. der, err := dummyCert(nil, "ca")
  305. if err != nil {
  306. t.Errorf("ca-cert: dummyCert: %v", err)
  307. }
  308. w.Write(der)
  309. default:
  310. t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
  311. }
  312. }))
  313. finish = func() {
  314. ca.Close()
  315. // make sure token cert was removed
  316. cancel := make(chan struct{})
  317. done := make(chan struct{})
  318. go func() {
  319. defer close(done)
  320. tick := time.NewTicker(100 * time.Millisecond)
  321. defer tick.Stop()
  322. for {
  323. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  324. if _, err := man.GetCertificate(hello); err != nil {
  325. return
  326. }
  327. select {
  328. case <-tick.C:
  329. case <-cancel:
  330. return
  331. }
  332. }
  333. }()
  334. select {
  335. case <-done:
  336. case <-time.After(5 * time.Second):
  337. close(cancel)
  338. t.Error("token cert was not removed")
  339. <-done
  340. }
  341. }
  342. return ca.URL, finish
  343. }
  344. // tests man.GetCertificate flow using the provided hello argument.
  345. // The domain argument is the expected domain name of a certificate request.
  346. func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
  347. url, finish := startACMEServerStub(t, man, domain)
  348. defer finish()
  349. // use EC key to run faster on 386
  350. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  351. if err != nil {
  352. t.Fatal(err)
  353. }
  354. man.Client = &acme.Client{
  355. Key: key,
  356. DirectoryURL: url,
  357. }
  358. // simulate tls.Config.GetCertificate
  359. var tlscert *tls.Certificate
  360. done := make(chan struct{})
  361. go func() {
  362. tlscert, err = man.GetCertificate(hello)
  363. close(done)
  364. }()
  365. select {
  366. case <-time.After(time.Minute):
  367. t.Fatal("man.GetCertificate took too long to return")
  368. case <-done:
  369. }
  370. if err != nil {
  371. t.Fatalf("man.GetCertificate: %v", err)
  372. }
  373. // verify the tlscert is the same we responded with from the CA stub
  374. if len(tlscert.Certificate) == 0 {
  375. t.Fatal("len(tlscert.Certificate) is 0")
  376. }
  377. cert, err := x509.ParseCertificate(tlscert.Certificate[0])
  378. if err != nil {
  379. t.Fatalf("x509.ParseCertificate: %v", err)
  380. }
  381. if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain {
  382. t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
  383. }
  384. }
  385. func TestAccountKeyCache(t *testing.T) {
  386. m := Manager{Cache: newMemCache()}
  387. ctx := context.Background()
  388. k1, err := m.accountKey(ctx)
  389. if err != nil {
  390. t.Fatal(err)
  391. }
  392. k2, err := m.accountKey(ctx)
  393. if err != nil {
  394. t.Fatal(err)
  395. }
  396. if !reflect.DeepEqual(k1, k2) {
  397. t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
  398. }
  399. }
  400. func TestCache(t *testing.T) {
  401. privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  402. if err != nil {
  403. t.Fatal(err)
  404. }
  405. tmpl := &x509.Certificate{
  406. SerialNumber: big.NewInt(1),
  407. Subject: pkix.Name{CommonName: "example.org"},
  408. NotAfter: time.Now().Add(time.Hour),
  409. }
  410. pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
  411. if err != nil {
  412. t.Fatal(err)
  413. }
  414. tlscert := &tls.Certificate{
  415. Certificate: [][]byte{pub},
  416. PrivateKey: privKey,
  417. }
  418. man := &Manager{Cache: newMemCache()}
  419. defer man.stopRenew()
  420. ctx := context.Background()
  421. if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
  422. t.Fatalf("man.cachePut: %v", err)
  423. }
  424. res, err := man.cacheGet(ctx, "example.org")
  425. if err != nil {
  426. t.Fatalf("man.cacheGet: %v", err)
  427. }
  428. if res == nil {
  429. t.Fatal("res is nil")
  430. }
  431. }
  432. func TestHostWhitelist(t *testing.T) {
  433. policy := HostWhitelist("example.com", "example.org", "*.example.net")
  434. tt := []struct {
  435. host string
  436. allow bool
  437. }{
  438. {"example.com", true},
  439. {"example.org", true},
  440. {"one.example.com", false},
  441. {"two.example.org", false},
  442. {"three.example.net", false},
  443. {"dummy", false},
  444. }
  445. for i, test := range tt {
  446. err := policy(nil, test.host)
  447. if err != nil && test.allow {
  448. t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
  449. }
  450. if err == nil && !test.allow {
  451. t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
  452. }
  453. }
  454. }
  455. func TestValidCert(t *testing.T) {
  456. key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  457. if err != nil {
  458. t.Fatal(err)
  459. }
  460. key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  461. if err != nil {
  462. t.Fatal(err)
  463. }
  464. key3, err := rsa.GenerateKey(rand.Reader, 512)
  465. if err != nil {
  466. t.Fatal(err)
  467. }
  468. cert1, err := dummyCert(key1.Public(), "example.org")
  469. if err != nil {
  470. t.Fatal(err)
  471. }
  472. cert2, err := dummyCert(key2.Public(), "example.org")
  473. if err != nil {
  474. t.Fatal(err)
  475. }
  476. cert3, err := dummyCert(key3.Public(), "example.org")
  477. if err != nil {
  478. t.Fatal(err)
  479. }
  480. now := time.Now()
  481. early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
  482. if err != nil {
  483. t.Fatal(err)
  484. }
  485. expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
  486. if err != nil {
  487. t.Fatal(err)
  488. }
  489. tt := []struct {
  490. domain string
  491. key crypto.Signer
  492. cert [][]byte
  493. ok bool
  494. }{
  495. {"example.org", key1, [][]byte{cert1}, true},
  496. {"example.org", key3, [][]byte{cert3}, true},
  497. {"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
  498. {"example.org", key1, [][]byte{cert1, {1}}, false},
  499. {"example.org", key1, [][]byte{{1}}, false},
  500. {"example.org", key1, [][]byte{cert2}, false},
  501. {"example.org", key2, [][]byte{cert1}, false},
  502. {"example.org", key1, [][]byte{cert3}, false},
  503. {"example.org", key3, [][]byte{cert1}, false},
  504. {"example.net", key1, [][]byte{cert1}, false},
  505. {"example.org", key1, [][]byte{early}, false},
  506. {"example.org", key1, [][]byte{expired}, false},
  507. }
  508. for i, test := range tt {
  509. leaf, err := validCert(test.domain, test.cert, test.key)
  510. if err != nil && test.ok {
  511. t.Errorf("%d: err = %v", i, err)
  512. }
  513. if err == nil && !test.ok {
  514. t.Errorf("%d: err is nil", i)
  515. }
  516. if err == nil && test.ok && leaf == nil {
  517. t.Errorf("%d: leaf is nil", i)
  518. }
  519. }
  520. }
  521. type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
  522. func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
  523. return f(ctx, key)
  524. }
  525. func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
  526. return fmt.Errorf("unsupported Put of %q = %q", key, data)
  527. }
  528. func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
  529. return fmt.Errorf("unsupported Delete of %q", key)
  530. }
  531. func TestManagerGetCertificateBogusSNI(t *testing.T) {
  532. m := Manager{
  533. Prompt: AcceptTOS,
  534. Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
  535. return nil, fmt.Errorf("cache.Get of %s", key)
  536. }),
  537. }
  538. tests := []struct {
  539. name string
  540. wantErr string
  541. }{
  542. {"foo.com", "cache.Get of foo.com"},
  543. {"foo.com.", "cache.Get of foo.com"},
  544. {`a\b.com`, "acme/autocert: server name contains invalid character"},
  545. {`a/b.com`, "acme/autocert: server name contains invalid character"},
  546. {"", "acme/autocert: missing server name"},
  547. {"foo", "acme/autocert: server name component count invalid"},
  548. {".foo", "acme/autocert: server name component count invalid"},
  549. {"foo.", "acme/autocert: server name component count invalid"},
  550. {"fo.o", "cache.Get of fo.o"},
  551. }
  552. for _, tt := range tests {
  553. _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
  554. got := fmt.Sprint(err)
  555. if got != tt.wantErr {
  556. t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
  557. }
  558. }
  559. }