1 /*
2  * MsgTrans - Message Transport Framework for DLang. Based on TCP, WebSocket, UDP transmission protocol.
3  *
4  * Copyright (C) 2019 HuntLabs
5  *
6  * Website: https://www.msgtrans.org
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module msgtrans.channel.tcp.TcpClientChannel;
13 
14 import msgtrans.DefaultSessionManager;
15 import msgtrans.executor;
16 import msgtrans.channel.ClientChannel;
17 import msgtrans.channel.TransportSession;
18 import msgtrans.channel.tcp.TcpCodec;
19 import msgtrans.channel.tcp.TcpTransportSession;
20 import msgtrans.MessageBuffer;
21 import msgtrans.MessageHandler;
22 import msgtrans.MessageTransport;
23 import msgtrans.Packet;
24 import msgtrans.TransportContext;
25 import msgtrans.ee2e.message.MsgDefine;
26 import msgtrans.ee2e.crypto;
27 import msgtrans.ee2e.common;
28 import msgtrans.MessageTransportClient;
29 import hunt.Exceptions;
30 import hunt.io.channel.Common;
31 // import hunt.concurrency.FuturePromise;
32 import hunt.logging.ConsoleLogger;
33 import hunt.net;
34 
35 import google.protobuf;
36 
37 import std.array;
38 import std.base64;
39 import std.format;
40 
41 import core.sync.condition;
42 import core.sync.mutex;
43 
44 
/**
 * Client-side TCP channel of MsgTrans.
 *
 * Wraps a hunt-net `NetClient`, decodes inbound frames with `TcpCodec`, and
 * routes each `MessageBuffer` either to a `MessageHandler` or to an executor
 * registered on the owning `MessageTransport`. When
 * `MessageTransportClient.isEE2E` is set:
 * - a key exchange (`MESSAGE.INITIATE` / `MESSAGE.FINALIZE`) is started as soon
 *   as the connection opens;
 * - other outbound messages go through `common.encrypted_encode`;
 * - inbound messages go through `common.encrypted_decode`.
 */
class TcpClientChannel : ClientChannel {
    private string _host;
    private ushort _port;

    private MessageTransport _messageTransport;
    private NetClient _client;
    private NetClientOptions _options;
    private Connection _connection;
    private CloseHandler _closeHandler;

    // Guards _connectAttemptDone (and the hand-off of _connection); paired with
    // _connectCondition so connect() can block until the async attempt resolves.
    private Mutex _connectLocker;
    private Condition _connectCondition;

    // Set under _connectLocker once the current connect attempt has either
    // succeeded or failed. Waiting on this predicate instead of a bare
    // notification avoids losing a wake-up that fires before connect() waits.
    private bool _connectAttemptDone;

    /**
     * Params:
     *   host = remote host to connect to
     *   port = remote TCP port
     */
    this(string host, ushort port) {
        _host = host;
        _port = port;

        _options = new NetClientOptions();
        _options.setIdleTimeout(15.seconds);
        _options.setConnectTimeout(5.seconds);

        _connectLocker = new Mutex();
        _connectCondition = new Condition(_connectLocker);
    }

    /// Attaches the transport whose handlers/executors receive inbound messages.
    void set(MessageTransport transport) {
        _messageTransport = transport;
    }

    /// Registers a callback invoked when the underlying connection closes.
    void onClose(CloseHandler handler) {
        _closeHandler = handler;
    }

    /**
     * Starts the end-to-end key exchange.
     *
     * Sends this client's salt and EC public key, both Base64-encoded, in a
     * `MESSAGE.INITIATE` frame.
     *
     * Throws: IOException (from `send`) if the channel is not connected.
     */
    void keyExchangeInitiate() {
        KeyExchangeRequest request = new KeyExchangeRequest;
        KeyInfo keyInfo = new KeyInfo;

        keyInfo.salt_32bytes = Base64.encode(MessageTransportClient.client_key.salt);
        keyInfo.ec_public_key_65bytes = Base64.encode(MessageTransportClient.client_key.ec_pub_key);

        request.key_info = keyInfo;
        request.key_exchange_type = KeyExchangeType.KEY_EXCHANGE_INITIATE;

        logInfo("salt :%s", keyInfo.salt_32bytes);
        logInfo("%s", keyInfo.ec_public_key_65bytes);

        send(new MessageBuffer(MESSAGE.INITIATE, request.toProtobuf.array));
    }

    /**
     * Creates a fresh `NetClient` and installs the connection-lifecycle handler.
     *
     * The client is configured with `TcpCodec`. Called by `connect()` before
     * every attempt.
     */
    private void initialize() {
        _client = NetUtil.createNetClient(_options);
        _client.setCodec(new TcpCodec());

        _client.setHandler(new class NetConnectionHandler {

            override void connectionOpened(Connection connection) {
                version(HUNT_DEBUG) infof("Connection created: %s", connection.getRemoteAddress());

                _connectLocker.lock();
                scope(exit) {
                    _connectLocker.unlock();
                }

                _connection = connection;
                _connectAttemptDone = true;
                _connectCondition.notifyAll();

                // Runs while the lock is still held: a caller blocked in
                // connect() resumes only after keyExchangeInitiate() returns.
                if (MessageTransportClient.isEE2E) {
                    keyExchangeInitiate();
                }
            }

            override void connectionClosed(Connection connection) {
                version(HUNT_DEBUG) infof("Connection closed: %s", connection.getRemoteAddress());
                _connection = null;
                if (_closeHandler !is null) {
                    // The close handler receives a default-initialized context.
                    TransportContext t;
                    _closeHandler(t);
                }
            }

            override DataHandleStatus messageReceived(Connection connection, Object message) {
                MessageBuffer buffer = cast(MessageBuffer)message;
                if (buffer is null) {
                    warningf("expected type: MessageBuffer, message type: %s", typeid(message).name);
                } else {
                    dispatchMessage(connection, buffer);
                }

                return DataHandleStatus.Done;
            }

            override void exceptionCaught(Connection connection, Throwable t) {
                debug warning(t.msg);
            }

            override void failedOpeningConnection(int connectionId, Throwable t) {
                debug warning(t.msg);

                _connectLocker.lock();
                scope(exit) {
                    _connectLocker.unlock();
                }

                // Mark the attempt as resolved so connect() stops waiting and
                // reports the failure instead of running into the timeout.
                _connectAttemptDone = true;
                _connectCondition.notifyAll();
            }

            override void failedAcceptingConnection(int connectionId, Throwable t) {
                debug warning(t.msg);
            }
        });
    }

    /**
     * Routes one decoded inbound message.
     *
     * Key-exchange frames (`INITIATE` / `FINALIZE`) are handled internally.
     * Other messages go to the registered `MessageHandler`, or to an executor
     * when no handler exists.
     */
    private void dispatchMessage(Connection connection, MessageBuffer message) {
        version(HUNT_DEBUG) {
            // Pass the text as an argument, never as the format string:
            // the payload may contain '%'.
            tracef("data received: %s", message.toString());
        }

        // tx: 00 00 27 11 00 00 00 05 00 00 00 00 00 00 00 00 57 6F 72 6C 64
        // rx: 00 00 4E 21 00 00 00 0B 00 00 00 00 00 00 00 00 48 65 6C 6C 6F 20 57 6F 72 6C 64

        uint messageId = message.id;
        if (messageId == MESSAGE.INITIATE || messageId == MESSAGE.FINALIZE) {
            keyExchangeRequest(message, connection);
            return;
        }

        MessageHandler handler = _messageTransport.getMessageHandler(messageId);
        if (handler is null) {
            dispatchForExecutor(connection, messageId, message);
        } else {
            TransportContext context = getContext(connection);
            handler(context, decryptIfNeeded(message));
        }
    }

    /**
     * Decodes an inbound message with `common.encrypted_decode` when EE2E is on.
     *
     * Uses the server key. Returns the message unchanged when
     * `MessageTransportClient.isEE2E` is false.
     */
    private MessageBuffer decryptIfNeeded(MessageBuffer message) {
        if (!MessageTransportClient.isEE2E) {
            return message;
        }
        version(HUNT_DEBUG) logInfo("......................");
        return common.encrypted_decode(message, MessageTransportClient.server_key, true);
    }

    /// Hands a message to the executor registered for `messageId`, warning if none exists.
    private void dispatchForExecutor(Connection connection, uint messageId, MessageBuffer message) {
        ExecutorInfo executorInfo = _messageTransport.getExecutor(messageId);
        if (executorInfo == ExecutorInfo.init) {
            warning("No Executor found for id: ", messageId);
            return;
        }

        TransportContext context = getContext(connection);
        executorInfo.execute(context, decryptIfNeeded(message));
    }

    /**
     * Returns a `TransportContext` bound to this connection's session.
     *
     * The `TcpTransportSession` is cached on the connection under the
     * "ChannelSession" attribute and created on first use.
     */
    private TransportContext getContext(Connection connection) {
        enum string ChannelSession = "ChannelSession";
        TcpTransportSession session = cast(TcpTransportSession)connection.getAttribute(ChannelSession);
        if (session is null) {
            session = new TcpTransportSession(nextClientSessionId(), connection);
            connection.setAttribute(ChannelSession, session);
        }

        return TransportContext(null, session);
    }

    /**
     * Handles the server side of the key exchange.
     *
     * - `INITIATE`: stores the server's public key and salt; if
     *   `common.keyCalculate` succeeds, replies with an empty `FINALIZE`.
     * - `FINALIZE`: logs completion.
     */
    private void keyExchangeRequest(MessageBuffer message, Connection connection) {
        switch (message.id) {
            case MESSAGE.INITIATE:
            {
                KeyExchangeRequest response = new KeyExchangeRequest;
                message.data.fromProtobuf!KeyExchangeRequest(response);

                // The payload comes from the network: reject it instead of
                // dereferencing a missing key_info sub-message.
                if (response.key_info is null) {
                    logError("key exchange INITIATE without key_info");
                    break;
                }

                MessageTransportClient.server_key.ec_pub_key = Base64.decode(response.key_info.ec_public_key_65bytes);
                MessageTransportClient.server_key.salt = Base64.decode(response.key_info.salt_32bytes);

                if (common.keyCalculate(MessageTransportClient.client_key, MessageTransportClient.server_key)) {
                    send(new MessageBuffer(cast(uint)MESSAGE.FINALIZE, cast(ubyte[])[]));
                } else {
                    logError("keyCalculate error");
                }
                break;
            }
            case MESSAGE.FINALIZE:
            {
                logInfo("======================Key exchange completed======================");
                break;
            }
            default:
                break;
        }
    }

    /**
     * (Re)creates the underlying client and connects to host:port.
     *
     * Blocks until the attempt resolves (success or failure callback), or
     * until the configured connect timeout elapses. A negative timeout waits
     * indefinitely.
     *
     * Throws:
     *   IOException if the attempt resolved without a connection;
     *   TimeoutException if nothing resolved within the connect timeout.
     */
    void connect() {
        initialize();

        // Reset before starting the attempt, so a callback that fires during
        // _client.connect() is not overwritten afterwards.
        _connectLocker.lock();
        _connectAttemptDone = false;
        _connectLocker.unlock();

        _client.connect(_host, _port);

        _connectLocker.lock();
        scope(exit) {
            _connectLocker.unlock();
        }

        Duration connectTimeout = _options.getConnectTimeout();
        if (connectTimeout.isNegative()) {
            version (HUNT_DEBUG) infof("connecting...");
            // Loop: guards against spurious wake-ups.
            while (!_connectAttemptDone) {
                _connectCondition.wait();
            }
        } else {
            import core.time : MonoTime;

            version (HUNT_DEBUG) infof("waiting for the connection in %s ...", connectTimeout);
            immutable deadline = MonoTime.currTime + connectTimeout;
            while (!_connectAttemptDone) {
                Duration remaining = deadline - MonoTime.currTime;
                if (remaining <= Duration.zero || !_connectCondition.wait(remaining)) {
                    if (_connectAttemptDone) {
                        break;
                    }
                    warningf("connect timeout in %s", connectTimeout);
                    _client.close();
                    throw new TimeoutException();
                }
            }
        }

        if (!_client.isConnected()) {
            string msg = format("Failed to connect to %s:%d", _host, _port);
            warning(msg);
            _client.close();
            throw new IOException(msg);
        }
    }

    /// True when a client exists and reports an open connection.
    bool isConnected() {
        return _client !is null && _client.isConnected();
    }

    /**
     * Encodes `message` into packets and writes them to the current connection.
     *
     * Non-key-exchange messages are passed through `common.encrypted_encode`
     * first when EE2E is enabled.
     *
     * Throws: IOException if there is no usable connection.
     */
    void send(MessageBuffer message) {
        // Snapshot once: the close callback may null the field concurrently.
        Connection connection = _connection;
        if (!isConnected() || connection is null) {
            throw new IOException("Connection broken!");
        }

        if (MessageTransportClient.isEE2E && message.id != MESSAGE.INITIATE && message.id != MESSAGE.FINALIZE) {
            message = common.encrypted_encode(message, MessageTransportClient.client_key, MessageTransportClient.server_key);
        }

        foreach (ubyte[] data; Packet.encode(message)) {
            connection.write(data);
        }
    }

    /// Closes the underlying client, if one was created.
    void close() {
        if (_client !is null) {
            _client.close();
        }
    }
}
331