forked from pneymrl2f/nightingale
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
7.3 KiB
251 lines
7.3 KiB
// Copyright (c) 2017 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
|
|
|
|
package tchannel
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"strconv"
|
|
"time"
|
|
|
|
"golang.org/x/net/context"
|
|
)
|
|
|
|
func (ch *Channel) outboundHandshake(ctx context.Context, c net.Conn, outboundHP string, events connectionEvents) (_ *Connection, err error) {
|
|
defer setInitDeadline(ctx, c)()
|
|
defer func() {
|
|
err = ch.initError(c, outbound, 1, err)
|
|
}()
|
|
|
|
msg := &initReq{initMessage: ch.getInitMessage(ctx, 1)}
|
|
if err := ch.writeMessage(c, msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res := &initRes{}
|
|
id, err := ch.readMessage(c, res)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if id != msg.id {
|
|
return nil, NewSystemError(ErrCodeProtocol, "received initRes with invalid ID, wanted %v, got %v", msg.id, id)
|
|
}
|
|
|
|
if res.Version != CurrentProtocolVersion {
|
|
return nil, unsupportedProtocolVersion(res.Version)
|
|
}
|
|
|
|
remotePeer, remotePeerAddress, err := parseRemotePeer(res.initParams, c.RemoteAddr())
|
|
if err != nil {
|
|
return nil, NewWrappedSystemError(ErrCodeProtocol, err)
|
|
}
|
|
|
|
return ch.newConnection(c, 1 /* initialID */, outboundHP, remotePeer, remotePeerAddress, events), nil
|
|
}
|
|
|
|
func (ch *Channel) inboundHandshake(ctx context.Context, c net.Conn, events connectionEvents) (_ *Connection, err error) {
|
|
id := uint32(math.MaxUint32)
|
|
|
|
defer setInitDeadline(ctx, c)()
|
|
defer func() {
|
|
err = ch.initError(c, inbound, id, err)
|
|
}()
|
|
|
|
req := &initReq{}
|
|
id, err = ch.readMessage(c, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if req.Version < CurrentProtocolVersion {
|
|
return nil, unsupportedProtocolVersion(req.Version)
|
|
}
|
|
|
|
remotePeer, remotePeerAddress, err := parseRemotePeer(req.initParams, c.RemoteAddr())
|
|
if err != nil {
|
|
return nil, NewWrappedSystemError(ErrCodeProtocol, err)
|
|
}
|
|
|
|
res := &initRes{initMessage: ch.getInitMessage(ctx, id)}
|
|
if err := ch.writeMessage(c, res); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ch.newConnection(c, 0 /* initialID */, "" /* outboundHP */, remotePeer, remotePeerAddress, events), nil
|
|
}
|
|
|
|
func (ch *Channel) getInitParams() initParams {
|
|
localPeer := ch.PeerInfo()
|
|
return initParams{
|
|
InitParamHostPort: localPeer.HostPort,
|
|
InitParamProcessName: localPeer.ProcessName,
|
|
InitParamTChannelLanguage: localPeer.Version.Language,
|
|
InitParamTChannelLanguageVersion: localPeer.Version.LanguageVersion,
|
|
InitParamTChannelVersion: localPeer.Version.TChannelVersion,
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) getInitMessage(ctx context.Context, id uint32) initMessage {
|
|
msg := initMessage{
|
|
id: id,
|
|
Version: CurrentProtocolVersion,
|
|
initParams: ch.getInitParams(),
|
|
}
|
|
if p := getTChannelParams(ctx); p != nil && p.hideListeningOnOutbound {
|
|
msg.initParams[InitParamHostPort] = ephemeralHostPort
|
|
}
|
|
|
|
return msg
|
|
}
|
|
|
|
func (ch *Channel) initError(c net.Conn, connDir connectionDirection, id uint32, err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
ch.log.WithFields(LogFields{
|
|
{"connectionDirection", connDir},
|
|
{"localAddr", c.LocalAddr().String()},
|
|
{"remoteAddr", c.RemoteAddr().String()},
|
|
ErrField(err),
|
|
}...).Error("Failed during connection handshake.")
|
|
|
|
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
|
err = ErrTimeout
|
|
}
|
|
if err == io.EOF {
|
|
err = NewWrappedSystemError(ErrCodeNetwork, io.EOF)
|
|
}
|
|
ch.writeMessage(c, &errorMessage{
|
|
id: id,
|
|
errCode: GetSystemErrorCode(err),
|
|
message: err.Error(),
|
|
})
|
|
c.Close()
|
|
return err
|
|
}
|
|
|
|
func (ch *Channel) writeMessage(c net.Conn, msg message) error {
|
|
frame := ch.connectionOptions.FramePool.Get()
|
|
defer ch.connectionOptions.FramePool.Release(frame)
|
|
|
|
if err := frame.write(msg); err != nil {
|
|
return err
|
|
}
|
|
return frame.WriteOut(c)
|
|
}
|
|
|
|
func (ch *Channel) readMessage(c net.Conn, msg message) (uint32, error) {
|
|
frame := ch.connectionOptions.FramePool.Get()
|
|
defer ch.connectionOptions.FramePool.Release(frame)
|
|
|
|
if err := frame.ReadIn(c); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if frame.Header.messageType != msg.messageType() {
|
|
if frame.Header.messageType == messageTypeError {
|
|
return frame.Header.ID, readError(frame)
|
|
}
|
|
return frame.Header.ID, NewSystemError(ErrCodeProtocol, "expected message type %v, got %v", msg.messageType(), frame.Header.messageType)
|
|
}
|
|
|
|
return frame.Header.ID, frame.read(msg)
|
|
}
|
|
|
|
func parseRemotePeer(p initParams, remoteAddr net.Addr) (PeerInfo, peerAddressComponents, error) {
|
|
var (
|
|
remotePeer PeerInfo
|
|
remotePeerAddress peerAddressComponents
|
|
ok bool
|
|
)
|
|
|
|
if remotePeer.HostPort, ok = p[InitParamHostPort]; !ok {
|
|
return remotePeer, remotePeerAddress, fmt.Errorf("header %v is required", InitParamHostPort)
|
|
}
|
|
if remotePeer.ProcessName, ok = p[InitParamProcessName]; !ok {
|
|
return remotePeer, remotePeerAddress, fmt.Errorf("header %v is required", InitParamProcessName)
|
|
}
|
|
|
|
// If the remote host:port is ephemeral, use the socket address as the
|
|
// host:port and set IsEphemeral to true.
|
|
if isEphemeralHostPort(remotePeer.HostPort) {
|
|
remotePeer.HostPort = remoteAddr.String()
|
|
remotePeer.IsEphemeral = true
|
|
}
|
|
|
|
remotePeer.Version.Language = p[InitParamTChannelLanguage]
|
|
remotePeer.Version.LanguageVersion = p[InitParamTChannelLanguageVersion]
|
|
remotePeer.Version.TChannelVersion = p[InitParamTChannelVersion]
|
|
|
|
address := remotePeer.HostPort
|
|
if sHost, sPort, err := net.SplitHostPort(address); err == nil {
|
|
address = sHost
|
|
if p, err := strconv.ParseUint(sPort, 10, 16); err == nil {
|
|
remotePeerAddress.port = uint16(p)
|
|
}
|
|
}
|
|
if address == "localhost" {
|
|
remotePeerAddress.ipv4 = 127<<24 | 1
|
|
} else if ip := net.ParseIP(address); ip != nil {
|
|
if ip4 := ip.To4(); ip4 != nil {
|
|
remotePeerAddress.ipv4 = binary.BigEndian.Uint32(ip4)
|
|
} else {
|
|
remotePeerAddress.ipv6 = address
|
|
}
|
|
} else {
|
|
remotePeerAddress.hostname = address
|
|
}
|
|
|
|
return remotePeer, remotePeerAddress, nil
|
|
}
|
|
|
|
// setInitDeadline applies a read/write deadline to c for the duration of the
// handshake: the context's deadline if it has one, otherwise five seconds
// from now. The returned function clears the deadline again and is meant to
// be deferred by the caller.
func setInitDeadline(ctx context.Context, c net.Conn) func() {
	// Used when the caller's context carries no deadline.
	const fallbackTimeout = 5 * time.Second

	expiry, bounded := ctx.Deadline()
	if !bounded {
		expiry = time.Now().Add(fallbackTimeout)
	}
	// A failure to set the deadline is ignored here.
	_ = c.SetDeadline(expiry)

	return func() {
		// The zero time removes the deadline.
		var none time.Time
		_ = c.SetDeadline(none)
	}
}
|
|
|
|
func readError(frame *Frame) error {
|
|
errMsg := &errorMessage{
|
|
id: frame.Header.ID,
|
|
}
|
|
if err := frame.read(errMsg); err != nil {
|
|
return err
|
|
}
|
|
|
|
return errMsg.AsSystemError()
|
|
}
|
|
|
|
func unsupportedProtocolVersion(got uint16) error {
|
|
return NewSystemError(ErrCodeProtocol, "unsupported protocol version %d from peer, expected %v", got, CurrentProtocolVersion)
|
|
}
|