123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- // Copyright 2014 beego Author. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package session
- import (
- "bytes"
- "crypto/cipher"
- "crypto/hmac"
- "crypto/rand"
- "crypto/sha1"
- "crypto/subtle"
- "encoding/base64"
- "encoding/gob"
- "errors"
- "fmt"
- "io"
- "strconv"
- "time"
- "github.com/astaxie/beego/utils"
- )
// init pre-registers the container types most commonly stored in a session
// map. The session map is map[interface{}]interface{}, and gob can only
// encode or decode a concrete type behind an interface{} after that type has
// been registered.
func init() {
	samples := []interface{}{
		[]interface{}{},
		map[int]interface{}{},
		map[string]interface{}{},
		map[interface{}]interface{}{},
		map[string]string{},
		map[int]string{},
		map[int]int{},
		map[int]int64{},
	}
	for _, sample := range samples {
		gob.Register(sample)
	}
}
// EncodeGob serializes a session map with encoding/gob.
//
// The map is typed map[interface{}]interface{}, so gob must know every
// concrete key and value type before it can encode it. EncodeGob registers
// the concrete types of all keys and values it finds on the fly.
//
// Nil keys and values are skipped during registration. gob.Register cannot
// take a nil interface (it panics), but gob itself encodes a nil interface
// value without trouble, so storing nil in a session no longer crashes.
//
// On failure it returns a nil slice and the encoder's error.
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
	for k, v := range obj {
		if k != nil {
			gob.Register(k)
		}
		if v != nil {
			gob.Register(v)
		}
	}
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(obj); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// DecodeGob is the inverse of EncodeGob: it gob-decodes encoded into a
// session map. On any decoding error it returns a nil map together with
// the decoder's error.
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
	var result map[interface{}]interface{}
	if err := gob.NewDecoder(bytes.NewReader(encoded)).Decode(&result); err != nil {
		return nil, err
	}
	return result, nil
}
- // generateRandomKey creates a random key with the given strength.
- func generateRandomKey(strength int) []byte {
- k := make([]byte, strength)
- if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil {
- return utils.RandomCreateBytes(strength)
- }
- return k
- }
- // Encryption -----------------------------------------------------------------
- // encrypt encrypts a value using the given block in counter mode.
- //
- // A random initialization vector (http://goo.gl/zF67k) with the length of the
- // block size is prepended to the resulting ciphertext.
- func encrypt(block cipher.Block, value []byte) ([]byte, error) {
- iv := generateRandomKey(block.BlockSize())
- if iv == nil {
- return nil, errors.New("encrypt: failed to generate random iv")
- }
- // Encrypt it.
- stream := cipher.NewCTR(block, iv)
- stream.XORKeyStream(value, value)
- // Return iv + ciphertext.
- return append(iv, value...), nil
- }
// decrypt reverses encrypt: value must be iv || ciphertext, where the IV
// (http://goo.gl/zF67k) is exactly one block size long. The plaintext is
// returned in a newly allocated slice, and value is not modified.
//
// An input consisting of only the IV decrypts to an empty plaintext, which
// matches what encrypt produces for an empty value. Anything shorter than
// one block is rejected.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
	size := block.BlockSize()
	if len(value) < size {
		return nil, errors.New("decrypt: the value could not be decrypted")
	}
	iv, ciphertext := value[:size], value[size:]
	plaintext := make([]byte, len(ciphertext))
	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
	return plaintext, nil
}
- func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
- var err error
- var b []byte
- // 1. EncodeGob.
- if b, err = EncodeGob(value); err != nil {
- return "", err
- }
- // 2. Encrypt (optional).
- if b, err = encrypt(block, b); err != nil {
- return "", err
- }
- b = encode(b)
- // 3. Create MAC for "name|date|value". Extra pipe to be used later.
- b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
- h := hmac.New(sha1.New, []byte(hashKey))
- h.Write(b)
- sig := h.Sum(nil)
- // Append mac, remove name.
- b = append(b, sig...)[len(name)+1:]
- // 4. Encode to base64.
- b = encode(b)
- // Done.
- return string(b), nil
- }
- func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
- // 1. Decode from base64.
- b, err := decode([]byte(value))
- if err != nil {
- return nil, err
- }
- // 2. Verify MAC. Value is "date|value|mac".
- parts := bytes.SplitN(b, []byte("|"), 3)
- if len(parts) != 3 {
- return nil, errors.New("Decode: invalid value %v")
- }
- b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
- h := hmac.New(sha1.New, []byte(hashKey))
- h.Write(b)
- sig := h.Sum(nil)
- if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
- return nil, errors.New("Decode: the value is not valid")
- }
- // 3. Verify date ranges.
- var t1 int64
- if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
- return nil, errors.New("Decode: invalid timestamp")
- }
- t2 := time.Now().UTC().Unix()
- if t1 > t2 {
- return nil, errors.New("Decode: timestamp is too new")
- }
- if t1 < t2-gcmaxlifetime {
- return nil, errors.New("Decode: expired timestamp")
- }
- // 4. Decrypt (optional).
- b, err = decode(parts[1])
- if err != nil {
- return nil, err
- }
- if b, err = decrypt(block, b); err != nil {
- return nil, err
- }
- // 5. DecodeGob.
- dst, err := DecodeGob(b)
- if err != nil {
- return nil, err
- }
- return dst, nil
- }
- // Encoding -------------------------------------------------------------------
// encode returns value encoded with the URL-safe base64 alphabet, including
// '=' padding.
func encode(value []byte) []byte {
	return []byte(base64.URLEncoding.EncodeToString(value))
}
// decode reverses encode: it parses URL-safe, padded base64. On malformed
// input it returns a nil slice and the base64 error.
func decode(value []byte) ([]byte, error) {
	decoded, err := base64.URLEncoding.DecodeString(string(value))
	if err != nil {
		// DecodeString returns a partially filled buffer on error; discard it.
		return nil, err
	}
	return decoded, nil
}
|