summary refs log tree commit diff
path: root/vendor/github.com/chzyer/readline/remote.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/chzyer/readline/remote.go')
-rw-r--r--vendor/github.com/chzyer/readline/remote.go475
1 files changed, 475 insertions, 0 deletions
diff --git a/vendor/github.com/chzyer/readline/remote.go b/vendor/github.com/chzyer/readline/remote.go
new file mode 100644
index 0000000..74dbf56
--- /dev/null
+++ b/vendor/github.com/chzyer/readline/remote.go
@@ -0,0 +1,475 @@
+package readline
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"sync"
+	"sync/atomic"
+)
+
+type MsgType int16
+
+const (
+	T_DATA = MsgType(iota)
+	T_WIDTH
+	T_WIDTH_REPORT
+	T_ISTTY_REPORT
+	T_RAW
+	T_ERAW // exit raw
+	T_EOF
+)
+
+type RemoteSvr struct {
+	eof           int32
+	closed        int32
+	width         int32
+	reciveChan    chan struct{}
+	writeChan     chan *writeCtx
+	conn          net.Conn
+	isTerminal    bool
+	funcWidthChan func()
+	stopChan      chan struct{}
+
+	dataBufM sync.Mutex
+	dataBuf  bytes.Buffer
+}
+
+type writeReply struct {
+	n   int
+	err error
+}
+
+type writeCtx struct {
+	msg   *Message
+	reply chan *writeReply
+}
+
+func newWriteCtx(msg *Message) *writeCtx {
+	return &writeCtx{
+		msg:   msg,
+		reply: make(chan *writeReply),
+	}
+}
+
+func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) {
+	rs := &RemoteSvr{
+		width:      -1,
+		conn:       conn,
+		writeChan:  make(chan *writeCtx),
+		reciveChan: make(chan struct{}),
+		stopChan:   make(chan struct{}),
+	}
+	buf := bufio.NewReader(rs.conn)
+
+	if err := rs.init(buf); err != nil {
+		return nil, err
+	}
+
+	go rs.readLoop(buf)
+	go rs.writeLoop()
+	return rs, nil
+}
+
+func (r *RemoteSvr) init(buf *bufio.Reader) error {
+	m, err := ReadMessage(buf)
+	if err != nil {
+		return err
+	}
+	// receive isTerminal
+	if m.Type != T_ISTTY_REPORT {
+		return fmt.Errorf("unexpected init message")
+	}
+	r.GotIsTerminal(m.Data)
+
+	// receive width
+	m, err = ReadMessage(buf)
+	if err != nil {
+		return err
+	}
+	if m.Type != T_WIDTH_REPORT {
+		return fmt.Errorf("unexpected init message")
+	}
+	r.GotReportWidth(m.Data)
+
+	return nil
+}
+
+func (r *RemoteSvr) HandleConfig(cfg *Config) {
+	cfg.Stderr = r
+	cfg.Stdout = r
+	cfg.Stdin = r
+	cfg.FuncExitRaw = r.ExitRawMode
+	cfg.FuncIsTerminal = r.IsTerminal
+	cfg.FuncMakeRaw = r.EnterRawMode
+	cfg.FuncExitRaw = r.ExitRawMode
+	cfg.FuncGetWidth = r.GetWidth
+	cfg.FuncOnWidthChanged = func(f func()) {
+		r.funcWidthChan = f
+	}
+}
+
+func (r *RemoteSvr) IsTerminal() bool {
+	return r.isTerminal
+}
+
+func (r *RemoteSvr) checkEOF() error {
+	if atomic.LoadInt32(&r.eof) == 1 {
+		return io.EOF
+	}
+	return nil
+}
+
+func (r *RemoteSvr) Read(b []byte) (int, error) {
+	r.dataBufM.Lock()
+	n, err := r.dataBuf.Read(b)
+	r.dataBufM.Unlock()
+	if n == 0 {
+		if err := r.checkEOF(); err != nil {
+			return 0, err
+		}
+	}
+
+	if n == 0 && err == io.EOF {
+		<-r.reciveChan
+		r.dataBufM.Lock()
+		n, err = r.dataBuf.Read(b)
+		r.dataBufM.Unlock()
+	}
+	if n == 0 {
+		if err := r.checkEOF(); err != nil {
+			return 0, err
+		}
+	}
+
+	return n, err
+}
+
+func (r *RemoteSvr) writeMsg(m *Message) error {
+	ctx := newWriteCtx(m)
+	r.writeChan <- ctx
+	reply := <-ctx.reply
+	return reply.err
+}
+
+func (r *RemoteSvr) Write(b []byte) (int, error) {
+	ctx := newWriteCtx(NewMessage(T_DATA, b))
+	r.writeChan <- ctx
+	reply := <-ctx.reply
+	return reply.n, reply.err
+}
+
+func (r *RemoteSvr) EnterRawMode() error {
+	return r.writeMsg(NewMessage(T_RAW, nil))
+}
+
+func (r *RemoteSvr) ExitRawMode() error {
+	return r.writeMsg(NewMessage(T_ERAW, nil))
+}
+
+func (r *RemoteSvr) writeLoop() {
+	defer r.Close()
+
+loop:
+	for {
+		select {
+		case ctx, ok := <-r.writeChan:
+			if !ok {
+				break
+			}
+			n, err := ctx.msg.WriteTo(r.conn)
+			ctx.reply <- &writeReply{n, err}
+		case <-r.stopChan:
+			break loop
+		}
+	}
+}
+
+func (r *RemoteSvr) Close() error {
+	if atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
+		close(r.stopChan)
+		r.conn.Close()
+	}
+	return nil
+}
+
+func (r *RemoteSvr) readLoop(buf *bufio.Reader) {
+	defer r.Close()
+	for {
+		m, err := ReadMessage(buf)
+		if err != nil {
+			break
+		}
+		switch m.Type {
+		case T_EOF:
+			atomic.StoreInt32(&r.eof, 1)
+			select {
+			case r.reciveChan <- struct{}{}:
+			default:
+			}
+		case T_DATA:
+			r.dataBufM.Lock()
+			r.dataBuf.Write(m.Data)
+			r.dataBufM.Unlock()
+			select {
+			case r.reciveChan <- struct{}{}:
+			default:
+			}
+		case T_WIDTH_REPORT:
+			r.GotReportWidth(m.Data)
+		case T_ISTTY_REPORT:
+			r.GotIsTerminal(m.Data)
+		}
+	}
+}
+
+func (r *RemoteSvr) GotIsTerminal(data []byte) {
+	if binary.BigEndian.Uint16(data) == 0 {
+		r.isTerminal = false
+	} else {
+		r.isTerminal = true
+	}
+}
+
+func (r *RemoteSvr) GotReportWidth(data []byte) {
+	atomic.StoreInt32(&r.width, int32(binary.BigEndian.Uint16(data)))
+	if r.funcWidthChan != nil {
+		r.funcWidthChan()
+	}
+}
+
+func (r *RemoteSvr) GetWidth() int {
+	return int(atomic.LoadInt32(&r.width))
+}
+
+// -----------------------------------------------------------------------------
+
+type Message struct {
+	Type MsgType
+	Data []byte
+}
+
+func ReadMessage(r io.Reader) (*Message, error) {
+	m := new(Message)
+	var length int32
+	if err := binary.Read(r, binary.BigEndian, &length); err != nil {
+		return nil, err
+	}
+	if err := binary.Read(r, binary.BigEndian, &m.Type); err != nil {
+		return nil, err
+	}
+	m.Data = make([]byte, int(length)-2)
+	if _, err := io.ReadFull(r, m.Data); err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
+func NewMessage(t MsgType, data []byte) *Message {
+	return &Message{t, data}
+}
+
+func (m *Message) WriteTo(w io.Writer) (int, error) {
+	buf := bytes.NewBuffer(make([]byte, 0, len(m.Data)+2+4))
+	binary.Write(buf, binary.BigEndian, int32(len(m.Data)+2))
+	binary.Write(buf, binary.BigEndian, m.Type)
+	buf.Write(m.Data)
+	n, err := buf.WriteTo(w)
+	return int(n), err
+}
+
+// -----------------------------------------------------------------------------
+
+type RemoteCli struct {
+	conn        net.Conn
+	raw         RawMode
+	receiveChan chan struct{}
+	inited      int32
+	isTerminal  *bool
+
+	data  bytes.Buffer
+	dataM sync.Mutex
+}
+
+func NewRemoteCli(conn net.Conn) (*RemoteCli, error) {
+	r := &RemoteCli{
+		conn:        conn,
+		receiveChan: make(chan struct{}),
+	}
+	return r, nil
+}
+
+func (r *RemoteCli) MarkIsTerminal(is bool) {
+	r.isTerminal = &is
+}
+
+func (r *RemoteCli) init() error {
+	if !atomic.CompareAndSwapInt32(&r.inited, 0, 1) {
+		return nil
+	}
+
+	if err := r.reportIsTerminal(); err != nil {
+		return err
+	}
+
+	if err := r.reportWidth(); err != nil {
+		return err
+	}
+
+	// register sig for width changed
+	DefaultOnWidthChanged(func() {
+		r.reportWidth()
+	})
+	return nil
+}
+
+func (r *RemoteCli) writeMsg(m *Message) error {
+	r.dataM.Lock()
+	_, err := m.WriteTo(r.conn)
+	r.dataM.Unlock()
+	return err
+}
+
+func (r *RemoteCli) Write(b []byte) (int, error) {
+	m := NewMessage(T_DATA, b)
+	r.dataM.Lock()
+	_, err := m.WriteTo(r.conn)
+	r.dataM.Unlock()
+	return len(b), err
+}
+
+func (r *RemoteCli) reportWidth() error {
+	screenWidth := GetScreenWidth()
+	data := make([]byte, 2)
+	binary.BigEndian.PutUint16(data, uint16(screenWidth))
+	msg := NewMessage(T_WIDTH_REPORT, data)
+
+	if err := r.writeMsg(msg); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (r *RemoteCli) reportIsTerminal() error {
+	var isTerminal bool
+	if r.isTerminal != nil {
+		isTerminal = *r.isTerminal
+	} else {
+		isTerminal = DefaultIsTerminal()
+	}
+	data := make([]byte, 2)
+	if isTerminal {
+		binary.BigEndian.PutUint16(data, 1)
+	} else {
+		binary.BigEndian.PutUint16(data, 0)
+	}
+	msg := NewMessage(T_ISTTY_REPORT, data)
+	if err := r.writeMsg(msg); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (r *RemoteCli) readLoop() {
+	buf := bufio.NewReader(r.conn)
+	for {
+		msg, err := ReadMessage(buf)
+		if err != nil {
+			break
+		}
+		switch msg.Type {
+		case T_ERAW:
+			r.raw.Exit()
+		case T_RAW:
+			r.raw.Enter()
+		case T_DATA:
+			os.Stdout.Write(msg.Data)
+		}
+	}
+}
+
+func (r *RemoteCli) ServeBy(source io.Reader) error {
+	if err := r.init(); err != nil {
+		return err
+	}
+
+	go func() {
+		defer r.Close()
+		for {
+			n, _ := io.Copy(r, source)
+			if n == 0 {
+				break
+			}
+		}
+	}()
+	defer r.raw.Exit()
+	r.readLoop()
+	return nil
+}
+
+func (r *RemoteCli) Close() {
+	r.writeMsg(NewMessage(T_EOF, nil))
+}
+
+func (r *RemoteCli) Serve() error {
+	return r.ServeBy(os.Stdin)
+}
+
+func ListenRemote(n, addr string, cfg *Config, h func(*Instance), onListen ...func(net.Listener) error) error {
+	ln, err := net.Listen(n, addr)
+	if err != nil {
+		return err
+	}
+	if len(onListen) > 0 {
+		if err := onListen[0](ln); err != nil {
+			return err
+		}
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			break
+		}
+		go func() {
+			defer conn.Close()
+			rl, err := HandleConn(*cfg, conn)
+			if err != nil {
+				return
+			}
+			h(rl)
+		}()
+	}
+	return nil
+}
+
+func HandleConn(cfg Config, conn net.Conn) (*Instance, error) {
+	r, err := NewRemoteSvr(conn)
+	if err != nil {
+		return nil, err
+	}
+	r.HandleConfig(&cfg)
+
+	rl, err := NewEx(&cfg)
+	if err != nil {
+		return nil, err
+	}
+	return rl, nil
+}
+
+func DialRemote(n, addr string) error {
+	conn, err := net.Dial(n, addr)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+
+	cli, err := NewRemoteCli(conn)
+	if err != nil {
+		return err
+	}
+	return cli.Serve()
+}