1 /*
2  * MsgTrans - Message Transport Framework for DLang. Based on TCP, WebSocket, UDP transmission protocol.
3  *
4  * Copyright (C) 2019 HuntLabs
5  *
6  * Website: https://www.msgtrans.org
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module msgtrans.channel.tcp.TcpClientChannel;
13 
import msgtrans.DefaultSessionManager;
import msgtrans.executor;
import msgtrans.channel.ClientChannel;
import msgtrans.channel.TransportSession;
import msgtrans.channel.tcp.TcpCodec;
import msgtrans.channel.tcp.TcpTransportSession;
import msgtrans.MessageBuffer;
import msgtrans.MessageTransport;
import msgtrans.Packet;
import msgtrans.TransportContext;
import google.protobuf;
import std.array;
import msgtrans.ee2e.message.MsgDefine;
import msgtrans.ee2e.crypto;
import msgtrans.ee2e.common;
import msgtrans.MessageTransportClient;
import hunt.Exceptions;
// import hunt.concurrency.FuturePromise;
import hunt.logging.ConsoleLogger;
import hunt.net;
import std.base64;


import core.sync.condition;
import core.sync.mutex;
import core.time;

import std.format;
41 
42 /**
43  *
44  */
45 class TcpClientChannel : ClientChannel {
46     private string _host;
47     private ushort _port;
48 
49     private MessageTransport _messageTransport;
50     private NetClient _client;
51     private NetClientOptions _options;
52     private Connection _connection;
53     private CloseHandler _closeHandler;
54     private Mutex _connectLocker;
55     private Condition _connectCondition;
56 
57     this(string host, ushort port) {
58         _host = host;
59         _port = port;
60 
61         _options = new NetClientOptions();
62         _options.setIdleTimeout(15.seconds);
63         _options.setConnectTimeout(5.seconds);
64 
65         _connectLocker = new Mutex();
66         _connectCondition = new Condition(_connectLocker);
67     }
68 
69     void set(MessageTransport transport) {
70         _messageTransport = transport;
71     }
72 
73     void onClose(CloseHandler handler)
74     {
75       _closeHandler = handler;
76     }
77 
78     void keyExchangeInitiate()
79     {
80         KeyExchangeRequest keyExchangeRes = new KeyExchangeRequest;
81         KeyInfo keyInfo = new KeyInfo;
82 
83         keyInfo.salt_32bytes = Base64.encode(MessageTransportClient.client_key.salt);
84         keyInfo.ec_public_key_65bytes = Base64.encode(MessageTransportClient.client_key.ec_pub_key); //Base64.encode(MessageTransportClient.client_key.ec_pub_key);
85         //logInfo("%s",MessageTransportClient.client_key.ec_pub_key);
86 
87         keyExchangeRes.key_info = keyInfo;
88         keyExchangeRes.key_exchange_type = KeyExchangeType.KEY_EXCHANGE_INITIATE;
89 
90         logInfo("salt :%s",keyInfo.salt_32bytes);
91 
92         logInfo("%s",keyInfo.ec_public_key_65bytes);
93         send(new MessageBuffer(MESSAGE.INITIATE,keyExchangeRes.toProtobuf.array));
94     }
95 
96     private void initialize() {
97 
98         _client = NetUtil.createNetClient(_options);
99 
100         _client.setCodec(new TcpCodec());
101 
102         _client.setHandler(new class NetConnectionHandler {
103 
104             override void connectionOpened(Connection connection) {
105                 version(HUNT_DEBUG) infof("Connection created: %s", connection.getRemoteAddress());
106                 _connection = connection;
107                 _connectCondition.notifyAll();
108                 if (MessageTransportClient.isEE2E)
109                 {
110                   keyExchangeInitiate();
111                 }
112 
113             }
114 
115             override void connectionClosed(Connection connection) {
116                 version(HUNT_DEBUG) infof("Connection closed: %s", connection.getRemoteAddress());
117                 _connection = null;
118                 if(_closeHandler !is null)
119                 {
120                   TransportContext t;
121                   _closeHandler(t);
122                 }
123                 // client.close();
124             }
125 
126             override void messageReceived(Connection connection, Object message) {
127                 MessageBuffer buffer = cast(MessageBuffer)message;
128                 if(buffer is null) {
129                     warningf("expected type: MessageBuffer, message type: %s", typeid(message).name);
130                 } else {
131                     dispatchMessage(connection, buffer);
132                 }
133             }
134 
135             override void exceptionCaught(Connection connection, Throwable t) {
136                 debug warning(t.msg);
137             }
138 
139             override void failedOpeningConnection(int connectionId, Throwable t) {
140                 debug warning(t.msg);
141                 // _client.close();
142                 _connectCondition.notifyAll();
143             }
144 
145             override void failedAcceptingConnection(int connectionId, Throwable t) {
146                 debug warning(t.msg);
147             }
148         });
149     }
150 
151 
152     private void dispatchMessage(Connection connection, MessageBuffer message ) {
153         version(HUNT_DEBUG) {
154             string str = format("data received: %s", message.toString());
155             tracef(str);
156         }
157 
158         // tx: 00 00 27 11 00 00 00 05 00 00 00 00 00 00 00 00 57 6F 72 6C 64
159         // rx: 00 00 4E 21 00 00 00 0B 00 00 00 00 00 00 00 00 48 65 6C 6C 6F 20 57 6F 72 6C 64
160 
161         uint messageId = message.id;
162         if (messageId == MESSAGE.INITIATE || messageId == MESSAGE.FINALIZE)
163         {
164             keyExchangeRequest(message,connection);
165             return;
166         }
167 
168         ExecutorInfo executorInfo = _messageTransport.getExecutor(messageId);
169         if(executorInfo == ExecutorInfo.init) {
170             warning("No Executor found for id: ", messageId);
171         } else {
172             enum string ChannelSession = "ChannelSession";
173             TcpTransportSession session = cast(TcpTransportSession)connection.getAttribute(ChannelSession);
174             if(session is null ){
175                 session = new TcpTransportSession(nextClientSessionId(), connection);
176                 connection.setAttribute(ChannelSession, session);
177             }
178 
179             TransportContext context = TransportContext(null, session);
180 
181             if (MessageTransportClient.isEE2E)
182             {
183                 logInfo("......................");
184                 message = common.encrypted_decode(message,MessageTransportClient.server_key, true);
185             }
186 
187             executorInfo.execute(context, message);
188         }
189     }
190 
191     private void keyExchangeRequest(MessageBuffer message, Connection connection)
192     {
193         switch(message.id)
194         {
195             case MESSAGE.INITIATE :
196             {
197                 KeyExchangeRequest keyExchangeRes = new KeyExchangeRequest;
198                 message.data.fromProtobuf!KeyExchangeRequest(keyExchangeRes);
199 
200                 MessageTransportClient.server_key.ec_pub_key = Base64.decode(keyExchangeRes.key_info.ec_public_key_65bytes);
201                 MessageTransportClient.server_key.salt = Base64.decode(keyExchangeRes.key_info.salt_32bytes);
202 
203                 //logInfo("server pub : %s" , MessageTransportClient.server_key.ec_pub_key);
204                 //logInfo("service salt : %s", MessageTransportClient.server_key.salt);
205 
206                 if (common.keyCalculate(MessageTransportClient.client_key,MessageTransportClient.server_key))
207                 {
208                     send(new MessageBuffer(MESSAGE.FINALIZE,[]));
209                 }else
210                 {
211                     logError("keyCalculate error");
212                 }
213                 break;
214             }
215             case MESSAGE.FINALIZE :
216             {
217                  logInfo("======================Key exchange completed======================");
218                  break;
219             }
220             default : break;
221         }
222 
223     }
224 
225     void connect() {
226         _connectLocker.lock();
227         scope(exit) {
228             _connectLocker.unlock();
229         }
230 
231         //if(_client !is null) {
232         //    return;
233         //}
234 
235         initialize();
236 
237         _client.connect(_host, _port);
238 
239         if(_client.isConnected())
240             return;
241 
242         Duration connectTimeout = _options.getConnectTimeout();
243         if(connectTimeout.isNegative()) {
244             version (HUNT_DEBUG) infof("connecting...");
245             _connectCondition.wait();
246         } else {
247             version (HUNT_DEBUG) infof("waiting for the connection in %s ...", connectTimeout);
248             bool r = _connectCondition.wait(connectTimeout);
249             if(r) {
250                 if(!_client.isConnected()) {
251                     string msg = format("Failed to connect to %s:%d", _host, _port);
252                     warning(msg);
253                     _client.close();
254                     throw new IOException(msg);
255                 }
256 
257             } else {
258                 warningf("connect timeout in %s", connectTimeout);
259                 _client.close();
260                 throw new TimeoutException();
261             }
262         }
263     }
264 
265     bool isConnected() {
266         return _client !is null && _client.isConnected();
267     }
268 
269     void send(MessageBuffer message) {
270         if(!isConnected()) {
271             throw new IOException("Connection broken!");
272         }
273         if (MessageTransportClient.isEE2E && (message.id != MESSAGE.INITIATE  && message.id != MESSAGE.FINALIZE))
274         {
275             message = common.encrypted_encode(message,MessageTransportClient.client_key,MessageTransportClient.server_key);
276         }
277         ubyte[][] buffers = Packet.encode(message);
278         foreach(ubyte[] data; buffers) {
279             _connection.write(data);
280         }
281     }
282 
283     void close() {
284         if(_client !is null) {
285             _client.close();
286         }
287     }
288 }
289