// wxpay.go (7.77 KB) - authored by haoyanbin, commit 1bd9d1c6.

package wx

import (
	"bytes"
	"crypto/md5"
	"crypto/tls"
	"encoding/base64"
	"encoding/hex"
	"encoding/pem"
	"encoding/xml"
	"errors"
	"io/ioutil"
	"math/rand"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"time"

	"golang.org/x/crypto/pkcs12"
)

// Client holds the credentials, HTTP clients and endpoint settings used by the
// request helpers in this file.
type Client struct {
	// appId is the APPID of the official account that the WeChat Pay
	// application was made for.
	appId        string
	// apiKey is the API key configured at https://pay.weixin.qq.com under
	// Account Settings > Security Settings > API Security > API Key > Set API Key.
	apiKey       string
	// mchId is the WeChat Pay merchant number (PartnerID), sent by email once
	// the merchant profile has passed WeChat Pay review.
	mchId        string
	// Client is the HTTP client used by doRequest and getSignKey.
	Client       *http.Client
	// tlsClient is the certificate-bearing HTTP client used by
	// doRequestWithTLS; it stays nil until LoadCert or LoadCertFromBase64
	// installs one.
	tlsClient    *http.Client
	// apiDomain is the base URL that BuildAPI appends path segments to.
	apiDomain    string
	// NotifyURL, when non-empty, is used by URLValues as the "notify_url"
	// parameter for requests that do not carry one themselves.
	NotifyURL    string
	// isProduction records which environment New was asked for; getKey signs
	// with apiKey directly only when it is true.
	isProduction bool
}

// New builds a Client from the given credentials.
//
// The returned client sends ordinary requests through http.DefaultClient and
// targets kProductionURL when isProduction is true, kSandboxURL otherwise.
func New(appId, apiKey, mchId string, isProduction bool) (client *Client) {
	domain := kSandboxURL
	if isProduction {
		domain = kProductionURL
	}

	return &Client{
		appId:        appId,
		apiKey:       apiKey,
		mchId:        mchId,
		Client:       http.DefaultClient,
		apiDomain:    domain,
		isProduction: isProduction,
	}
}

// initTLSClient turns PKCS#12 data into an HTTP client that presents the
// contained certificate during the TLS handshake.
//
// cert is the raw PKCS#12 content and password unlocks it. When cert is empty
// no client is built and (nil, nil) is returned. A conversion failure from
// pkcs12ToPem is returned as-is with a nil client.
func initTLSClient(cert []byte, password string) (tlsClient *http.Client, err error) {
	if len(cert) == 0 {
		return nil, nil
	}

	keyPair, err := pkcs12ToPem(cert, password)
	if err != nil {
		return nil, err
	}

	tlsClient = &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				Certificates: []tls.Certificate{keyPair},
			},
			DisableCompression: true,
		},
	}
	return tlsClient, nil
}

// LoadCert reads the PKCS#12 certificate file at path and installs a TLS
// client built from it, using the merchant ID as the certificate password.
//
// It returns ErrNotFoundCertFile for an empty path, and otherwise any error
// from reading the file or building the TLS client. The previously installed
// TLS client is left untouched when an error occurs.
func (this *Client) LoadCert(path string) (err error) {
	if path == "" {
		return ErrNotFoundCertFile
	}

	raw, err := ioutil.ReadFile(path)
	if err != nil {
		return err
	}

	loaded, err := initTLSClient(raw, this.mchId)
	if err == nil {
		this.tlsClient = loaded
	}
	return err
}

// LoadCertFromBase64 decodes cert as standard base64 PKCS#12 data and installs
// a TLS client built from it, using the merchant ID as the certificate
// password.
//
// It returns ErrNotFoundCertFile for an empty string, and otherwise any error
// from base64 decoding or from building the TLS client. The previously
// installed TLS client is left untouched when an error occurs.
func (this *Client) LoadCertFromBase64(cert string) (err error) {
	if cert == "" {
		return ErrNotFoundCertFile
	}

	raw, err := base64.StdEncoding.DecodeString(cert)
	if err != nil {
		return err
	}

	loaded, err := initTLSClient(raw, this.mchId)
	if err == nil {
		this.tlsClient = loaded
	}
	return err
}

// URLValues converts param into the signed parameter set for a request.
//
// Starting from param.Params(), it fills "appid" with the client's APPID when
// the parameter is empty, always overwrites "mch_id" and "nonce_str", adds
// "notify_url" from the client's NotifyURL when the key is absent and
// NotifyURL is non-empty, and finally stores the MD5 signature computed with
// key under "sign". The returned error is always nil.
func (this *Client) URLValues(param Param, key string) (value url.Values, err error) {
	value = param.Params()

	if value.Get("appid") == "" {
		value.Set("appid", this.appId)
	}
	value.Set("mch_id", this.mchId)
	value.Set("nonce_str", GetNonceStr())

	if _, present := value["notify_url"]; !present && len(this.NotifyURL) > 0 {
		value.Set("notify_url", this.NotifyURL)
	}

	// The signature must be computed last so that it covers every parameter
	// set above.
	value.Set("sign", SignMD5(value, key))
	return value, nil
}

// doRequest sends param to url with the client's plain HTTP client and decodes
// the verified response into result. The raw response body is discarded.
func (this *Client) doRequest(method, url string, param Param, result interface{}) (err error) {
	if _, err = this.doRequestWithClient(this.Client, method, url, param, result); err != nil {
		return err
	}
	return nil
}

// doRequestWithTLS sends param to url with the certificate-bearing HTTP client
// and decodes the verified response into result.
//
// It returns ErrNotFoundTLSClient when no TLS client has been installed.
func (this *Client) doRequestWithTLS(method, url string, param Param, result interface{}) (err error) {
	secure := this.tlsClient
	if secure == nil {
		return ErrNotFoundTLSClient
	}

	if _, err = this.doRequestWithClient(secure, method, url, param, result); err != nil {
		return err
	}
	return nil
}

// doRequestWithClient signs param, posts it as XML through client, verifies
// the response signature and decodes the response into result.
//
// The signing key comes from getKey and is used both to sign the request and
// to verify the response. On success the raw response bytes are returned.
// They are also returned, together with the error, when only the final XML
// decoding into result fails. Every earlier failure yields a nil body.
func (this *Client) doRequestWithClient(client *http.Client, method, url string, param Param, result interface{}) (body []byte, err error) {
	signKey, err := this.getKey()
	if err != nil {
		return nil, err
	}

	values, err := this.URLValues(param, signKey)
	if err != nil {
		return nil, err
	}

	payload := strings.NewReader(URLValueToXML(values))
	req, err := http.NewRequest(method, url, payload)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/xml")
	req.Header.Set("Content-Type", "application/xml;charset=utf-8")

	resp, err := client.Do(req)
	// Do can return a non-nil response together with an error, so the body
	// close is scheduled before the error is inspected.
	if resp != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		return nil, err
	}

	raw, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Reject error responses and bad signatures before decoding anything into
	// the caller's result.
	if _, err = VerifyResponseData(raw, signKey); err != nil {
		return nil, err
	}

	return raw, xml.Unmarshal(raw, result)
}

// DoRequest is the exported entry point to doRequest: it signs param, sends it
// to url with the plain HTTP client and decodes the verified response into
// result.
func (this *Client) DoRequest(method, url string, param Param, result interface{}) (err error) {
	err = this.doRequest(method, url, param, result)
	return err
}

// getKey returns the key used to sign requests and verify responses.
//
// For a production client this is the configured API key. Otherwise the key is
// obtained from getSignKey, and its error, if any, is returned with an empty
// key.
func (this *Client) getKey() (key string, err error) {
	if this.isProduction {
		return this.apiKey, nil
	}

	key, err = this.getSignKey(this.apiKey)
	if err != nil {
		return "", err
	}
	return key, nil
}

// SignMD5 signs param with the client's configured API key and returns the
// upper-case hexadecimal MD5 signature produced by the package-level SignMD5.
func (this *Client) SignMD5(param url.Values) (sign string) {
	sign = SignMD5(param, this.apiKey)
	return sign
}

// getSignKey requests the sandbox signing key for this merchant.
//
// It posts "mch_id", "nonce_str" and a "sign" computed with apiKey as XML to
// the sandbox getsignkey endpoint through the plain HTTP client, and returns
// the SandboxSignKey field of the decoded response. Transport, read and XML
// decoding errors are returned with an empty key. The response's return code
// is not inspected here.
func (this *Client) getSignKey(apiKey string) (key string, err error) {
	form := url.Values{}
	form.Set("mch_id", this.mchId)
	form.Set("nonce_str", GetNonceStr())
	form.Set("sign", SignMD5(form, apiKey))

	payload := strings.NewReader(URLValueToXML(form))
	req, err := http.NewRequest("POST", "https://api.mch.weixin.qq.com/sandboxnew/pay/getsignkey", payload)
	if err != nil {
		return "", err
	}
	req.Header.Set("Accept", "application/xml")
	req.Header.Set("Content-Type", "application/xml;charset=utf-8")

	resp, err := this.Client.Do(req)
	// Do can return a non-nil response together with an error, so the body
	// close is scheduled before the error is inspected.
	if resp != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		return "", err
	}

	raw, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	var rsp GetSignKeyRsp
	if err = xml.Unmarshal(raw, &rsp); err != nil {
		return "", err
	}
	return rsp.SandboxSignKey, nil
}

// BuildAPI joins the client's API domain with the given path segments.
//
// Each segment is trimmed of surrounding whitespace and skipped when empty. A
// "/" separator is inserted only when the URL built so far does not end with
// one and the segment does not start with one; no slashes are ever removed.
func (this *Client) BuildAPI(paths ...string) string {
	api := this.apiDomain
	for _, segment := range paths {
		segment = strings.TrimSpace(segment)
		if segment == "" {
			continue
		}
		if !strings.HasSuffix(api, "/") && !strings.HasPrefix(segment, "/") {
			api += "/"
		}
		api += segment
	}
	return api
}

// URLValueToXML serialises m as an <xml> document with one child element per
// key, using the first value stored for that key.
//
// The values of "total_fee", "refund_fee" and "execute_time" are written
// verbatim; every other value is wrapped in a CDATA section. Any "]]>" inside
// such a value is split across two CDATA sections so that it cannot terminate
// the section early; an XML parser still reads back the original text.
// Elements are emitted in ascending key order so the output is deterministic.
func URLValueToXML(m url.Values) string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var xmlBuffer bytes.Buffer
	xmlBuffer.WriteString("<xml>")
	for _, key := range keys {
		value := m.Get(key)

		xmlBuffer.WriteString("<" + key + ">")
		switch key {
		case "total_fee", "refund_fee", "execute_time":
			xmlBuffer.WriteString(value)
		default:
			xmlBuffer.WriteString("<![CDATA[")
			// "]]>" cannot appear inside CDATA: close the section after "]]"
			// and reopen it before ">".
			xmlBuffer.WriteString(strings.Replace(value, "]]>", "]]]]><![CDATA[>", -1))
			xmlBuffer.WriteString("]]>")
		}
		xmlBuffer.WriteString("</" + key + ">")
	}
	xmlBuffer.WriteString("</xml>")
	return xmlBuffer.String()
}

// SignMD5 computes the upper-case hexadecimal MD5 signature of param.
//
// Parameters whose first value is empty are skipped. The remaining ones are
// ordered by parameter name, rendered as "name=value" and joined with "&".
// When key is non-empty, "key=<key>" is appended as the last element before
// hashing.
//
// Sorting is done on the names, not on the rendered "name=value" strings:
// "=" sorts after the digits, so sorting the rendered pairs would place
// "a1=..." before "a=..." and "x_10=..." before "x_1=...".
func SignMD5(param url.Values, key string) (sign string) {
	names := make([]string, 0, len(param))
	for name := range param {
		if param.Get(name) != "" {
			names = append(names, name)
		}
	}
	sort.Strings(names)

	pairs := make([]string, 0, len(names)+1)
	for _, name := range names {
		pairs = append(pairs, name+"="+param.Get(name))
	}
	if key != "" {
		pairs = append(pairs, "key="+key)
	}

	digest := md5.Sum([]byte(strings.Join(pairs, "&")))
	return strings.ToUpper(hex.EncodeToString(digest[:]))
}

// VerifyResponseData decodes the XML response in data into an XMLMap and
// checks it with VerifyResponseValues using key.
//
// An XML decoding failure is returned as (false, err); otherwise the result of
// VerifyResponseValues is returned unchanged.
func VerifyResponseData(data []byte, key string) (ok bool, err error) {
	fields := XMLMap{}
	if err = xml.Unmarshal(data, &fields); err != nil {
		return false, err
	}
	return VerifyResponseValues(url.Values(fields), key)
}

// VerifyResponseValues checks the status codes and the signature of a decoded
// response.
//
// It fails with the text of "return_msg" (or "retmsg" when that is empty) if
// "return_code" is not ReturnCodeSuccess, and with the text of "err_code_des"
// if "result_code" is not ReturnCodeSuccess. Otherwise it compares the "sign"
// parameter with the MD5 signature of the remaining parameters computed with
// key, and returns (true, nil) only when they match.
//
// Side effect: once both status codes have passed, the "sign" entry is removed
// from param.
func VerifyResponseValues(param url.Values, key string) (bool, error) {
	// Surface transport-level and business-level failures first.
	if param.Get("return_code") != ReturnCodeSuccess {
		reason := param.Get("return_msg")
		if reason == "" {
			reason = param.Get("retmsg")
		}
		return false, errors.New(reason)
	}
	if param.Get("result_code") != ReturnCodeSuccess {
		return false, errors.New(param.Get("err_code_des"))
	}

	// Verify the signature; the "sign" parameter itself is not part of the
	// signed data, so it is removed before recomputing.
	received := param.Get("sign")
	delete(param, "sign")
	if received == "" || received != SignMD5(param, key) {
		return false, errors.New("签名验证失败")
	}
	return true, nil
}

// GetNonceStr returns a 32-character random string made of lower-case ASCII
// letters and digits.
//
// Every character of the alphabet can be produced, including the last one.
// The generator is math/rand seeded from the current time on each call, so the
// result is not cryptographically secure.
func GetNonceStr() (nonceStr string) {
	const chars = "abcdefghijklmnopqrstuvwxyz0123456789"

	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	buf := make([]byte, 32)
	for i := range buf {
		// Intn(n) returns a value in [0, n), so n must be the full alphabet
		// length for the final character to be reachable.
		buf[i] = chars[r.Intn(len(chars))]
	}
	return string(buf)
}

// pkcs12ToPem converts PKCS#12 data into a tls.Certificate.
//
// The PKCS#12 content is unlocked with password and split into PEM blocks by
// pkcs12.ToPEM. All blocks are concatenated into one PEM buffer, which is
// handed to tls.X509KeyPair as both the certificate and the key input. A
// conversion failure returns the zero certificate together with the error.
func pkcs12ToPem(p12 []byte, password string) (cert tls.Certificate, err error) {
	pemBlocks, err := pkcs12.ToPEM(p12, password)
	if err != nil {
		return cert, err
	}

	var encoded []byte
	for i := range pemBlocks {
		encoded = append(encoded, pem.EncodeToMemory(pemBlocks[i])...)
	}

	return tls.X509KeyPair(encoded, encoded)
}